diff options
author | Jianwei Xie <xiejw0217@gmail.com> | 2018-03-26 10:15:56 -0700 |
---|---|---|
committer | Jianwei Xie <xiejw0217@gmail.com> | 2018-03-26 10:15:56 -0700 |
commit | fb822667ca73e0d84d31fb5b1d03917db6b73095 (patch) | |
tree | 91369ad40fbeec8eae18a9351396257ef1de1a01 | |
parent | 90e015a863246cd022ce7257204f0d716ac2d400 (diff) | |
parent | 2b078a508b8c6c920db121f676650d7972749bd7 (diff) |
Merge commit for internal changes
176 files changed, 10170 insertions, 3021 deletions
diff --git a/SECURITY.md b/SECURITY.md index 378e776967..5ca304404d 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -168,7 +168,18 @@ below). Please use a descriptive subject line for your report email. After the initial reply to your report, the security team will endeavor to keep you informed of -the progress being made towards a fix and announcement. +the progress being made towards a fix and announcement. + +In addition, please include the following information along with your report: + +* Your name and affiliation (if any). +* A description the technical details of the vulnerabilities. It is very + important to let us know how we can reproduce your findings. +* An explanation who can exploit this vulnerability, and what they gain when + doing so -- write an attack scenario. This will help us evaluate your report + quickly, especially if the issue is complex. +* Whether this vulnerability public or known to third parties. If it is, please + provide details. If you believe that an existing (public) issue is security-related, please send an email to `security@tensorflow.org`. 
The email should include the issue ID and @@ -233,7 +244,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= ### Known vulnerabilities -| Type | Versions affected | Reported by | Additional Information | -|-------------------|:-----------------:|--------------------|-----------------------------| -| out of bounds read| <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | +| Type | Versions affected | Reported by | Additional Information | +|--------------------|:-----------------:|--------------------|-----------------------------| +| Out Of Bounds Read | <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index bea5a121b3..8df7b56623 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -31,6 +31,8 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/common_runtime/eager:copy_to_device_node", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -68,6 +70,7 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 2402a6d044..eaeb2fd07a 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/node_def_util.h" @@ -161,29 +162,32 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { DCHECK(h); - h->Unref(); + if (h->handle) { + h->handle->Unref(); + } + delete h; } TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { - return static_cast<TF_DataType>(h->dtype); + return static_cast<TF_DataType>(h->handle->dtype); } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { const tensorflow::Tensor* t = nullptr; - status->status = h->Tensor(&t); + status->status = h->handle->Tensor(&t); return t == nullptr ? 0 : t->dims(); } int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { const tensorflow::Tensor* t = nullptr; - status->status = h->Tensor(&t); + status->status = h->handle->Tensor(&t); return t == nullptr ? 0 : t->dim_size(dim_index); } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { tensorflow::Device* d = nullptr; - status->status = h->OpDevice(&d); + status->status = h->handle->OpDevice(&d); return (d == nullptr) ? 
"/job:localhost/replica:0/task:0/device:CPU:0" : d->name().c_str(); } @@ -193,7 +197,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { tensorflow::Device* d = nullptr; tensorflow::Device* op_device = nullptr; const tensorflow::Tensor* t = nullptr; - status->status = h->TensorAndDevice(&t, &d, &op_device); + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); if (!status->status.ok()) return nullptr; if (!IsCPU(d)) { TF_SetStatus(status, TF_UNIMPLEMENTED, @@ -210,82 +214,6 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { } } // extern "C" -namespace { - -tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h, - TFE_Context* ctx, - tensorflow::Device* dstd, - TFE_TensorHandle** output) { - const tensorflow::Tensor* src = nullptr; - tensorflow::Device* srcd = nullptr; - // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept - // nullptr. - tensorflow::Device* src_opd = nullptr; - TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd)); - if (srcd == nullptr) srcd = ctx->context.HostCPU(); - bool is_same_device = - (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd)); - const bool dst_cpu = IsCPU(dstd); - const bool src_cpu = IsCPU(srcd); - // both_on_cpu can be true and yet is_same_device is false, if one of src/dst - // has device type XLA_CPU, and the other CPU. - const bool both_on_cpu = src_cpu && dst_cpu; - if (is_same_device || both_on_cpu) { - dstd = dst_cpu ? 
nullptr : dstd; - *output = new TFE_TensorHandle(*src, dstd, dstd); - return tensorflow::Status::OK(); - } - if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT && - !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) { - return tensorflow::errors::InvalidArgument( - "Can't copy Tensor with type ", - tensorflow::DataTypeString(src->dtype()), " to device ", - DeviceName(dstd), "."); - } - tensorflow::AllocatorAttributes attr; - if (src->dtype() == tensorflow::DT_VARIANT) { - attr.set_on_host(true); - } - tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape()); - if (src->shape().num_elements() == 0) { - dstd = dst_cpu ? nullptr : dstd; - *output = new TFE_TensorHandle(dst, dstd, dstd); - return tensorflow::Status::OK(); - } - tensorflow::DeviceContext* src_device_context = nullptr; - if (!src_cpu) { - src_device_context = srcd->tensorflow_gpu_device_info()->default_context; - } - tensorflow::DeviceContext* dst_device_context = nullptr; - if (!dst_cpu) { - dst_device_context = dstd->tensorflow_gpu_device_info()->default_context; - } - // TODO(ashankar): The Sync() call below may be more aggressive than - // necessary. It is based on knowledge of implementation details - that - // GPU devices are implemented using 3 streams - one for host->device copies, - // one for device->host copies and one for sending operations to the GPU. - // With that setup, Sync()ing across all 3 streams should be sufficient - // but more than necessary (since it waits for operations that might have - // nothing to do with this tensor to complete). - TF_RETURN_IF_ERROR(srcd->Sync()); - tensorflow::Notification n; - tensorflow::Status status; - tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context, - srcd, dstd, tensorflow::AllocatorAttributes(), - tensorflow::AllocatorAttributes(), src, &dst, - [&status, &n](const tensorflow::Status& s) { - status = s; - n.Notify(); - }); - n.WaitForNotification(); - if (status.ok()) { - dstd = dst_cpu ? 
nullptr : dstd; - *output = new TFE_TensorHandle(dst, dstd, dstd); - } - return status; -} -} // namespace - extern "C" { TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, @@ -335,12 +263,12 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { tensorflow::Device* d = nullptr; // TODO(agarwal): This call may block if h is not ready. Avoid this if // possible. - status->status = h->Device(&d); + status->status = h->handle->Device(&d); if (!status->status.ok()) return; if (!IsCPU(d)) op->device = d; } - h->Ref(); - op->inputs.push_back(h); + h->handle->Ref(); + op->inputs.push_back(h->handle); op->attrs.NumInputs(op->inputs.size()); } @@ -506,6 +434,37 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, namespace { +// TODO(apassos) move to TensorHandle +tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal( + tensorflow::TensorHandle* h, TFE_Context* ctx, const char* device_name, + TF_Status* status) { + status->status = ctx->context.GetStatus(); + if (!status->status.ok()) { + return nullptr; + } + tensorflow::Device* dstd = ctx->context.HostCPU(); + if (device_name != nullptr && strlen(device_name) > 0) { + status->status = + ctx->context.device_mgr()->LookupDevice(device_name, &dstd); + if (!status->status.ok()) return nullptr; + } + if (ctx->context.Async()) { + // Note that `h` may not be currently ready. However execution order will + // make sure that `h` is ready before the copy is actually done. + tensorflow::CopyToDeviceNode* node = + new tensorflow::CopyToDeviceNode(h, dstd, &ctx->context); + tensorflow::TensorHandle* output = node->dst(); + // Note that calling Add makes `node` accessible by the EagerExecutor + // thread. So further accesses need to be thread-safe. 
+ ctx->context.ExecutorAdd(node); + return output; + } else { + tensorflow::TensorHandle* output = nullptr; + status->status = h->CopyToDevice(&ctx->context, dstd, &output); + return output; + } +} + tensorflow::Status ValidateInputTypeAndPlacement( TFE_Context* ctx, tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op, @@ -518,7 +477,7 @@ tensorflow::Status ValidateInputTypeAndPlacement( for (int i = 0; i < op->inputs.size(); ++i) { const tensorflow::Device* expected_device = memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device; - TFE_TensorHandle* handle = op->inputs[i]; + tensorflow::TensorHandle* handle = op->inputs[i]; tensorflow::Device* handle_device = nullptr; TF_RETURN_IF_ERROR(handle->Device(&handle_device)); const tensorflow::Device* actual_device = @@ -560,8 +519,9 @@ tensorflow::Status ValidateInputTypeAndPlacement( // We are only here if the policy is warn or silent copies, so we should // trigger a copy. TF_Status* s = TF_NewStatus(); - TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice( - handle, ctx, expected_device->name().c_str(), s); + tensorflow::TensorHandle* copied_tensor = + TFE_TensorHandleCopyToDevice_Internal( + handle, ctx, expected_device->name().c_str(), s); tensorflow::Status status = s->status; TF_DeleteStatus(s); if (!status.ok()) { @@ -616,9 +576,10 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, tensorflow::Status Execute( TFE_Context* ctx, tensorflow::Device* device, - const tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>& op_inputs, + const tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>& + op_inputs, tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats, - TFE_TensorHandle** retvals, int num_retvals) { + tensorflow::TensorHandle** retvals, int num_retvals) { if (!ctx->context.SoftPlacement() && device == nullptr) { device = ctx->context.HostCPU(); } @@ -683,7 +644,7 @@ tensorflow::Status Execute( d = nullptr; } if 
(retvals[i] == nullptr) { - retvals[i] = new TFE_TensorHandle(outputs[i], d, op_device); + retvals[i] = new tensorflow::TensorHandle(outputs[i], d, op_device); } else { retvals[i]->SetTensorAndDevice(outputs[i], d, op_device); } @@ -711,9 +672,10 @@ class ExecuteNode : public tensorflow::EagerNode { } TFE_Context* ctx = op->ctx; for (int i = 0; i < num_retvals; ++i) { - TFE_TensorHandle* h = new TFE_TensorHandle(id, output_dtypes[i], ctx); + tensorflow::TensorHandle* h = + new tensorflow::TensorHandle(id, output_dtypes[i], &ctx->context); h->Ref(); - retvals[i] = h; + retvals[i] = new TFE_TensorHandle(h); retvals_[i] = h; } } @@ -745,54 +707,12 @@ class ExecuteNode : public tensorflow::EagerNode { private: TFE_Context* ctx_; tensorflow::Device* op_device_; - tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> inputs_; + tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_; tensorflow::KernelAndDevice* kernel_; std::unique_ptr<tensorflow::NodeExecStats> maybe_stats_; - tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals_; + tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals_; }; -class CopyToDeviceNode : public tensorflow::EagerNode { - public: - CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd, - TFE_Context* ctx) - : tensorflow::EagerNode(ctx->context.NextId()), - src_(src), - dstd_(dstd), - ctx_(ctx), - dst_(new TFE_TensorHandle(id, src_->dtype, ctx)) { - src_->Ref(); - dst_->Ref(); - } - - ~CopyToDeviceNode() override { - src_->Unref(); - dst_->Unref(); - } - - tensorflow::Status Run() override { - TFE_TensorHandle* temp = nullptr; - TF_RETURN_IF_ERROR(TensorHandleCopyToDevice(src_, ctx_, dstd_, &temp)); - const tensorflow::Tensor* tensor = nullptr; - tensorflow::Device* device = nullptr; - tensorflow::Device* op_device = nullptr; - tensorflow::Status status = - temp->TensorAndDevice(&tensor, &device, &op_device); - // `temp` is a ready handle. So the following call should return OK. 
- TF_DCHECK_OK(status) << status.error_message(); - DCHECK(tensor); - dst_->SetTensorAndDevice(*tensor, device, op_device); - temp->Unref(); - return tensorflow::Status::OK(); - } - - TFE_TensorHandle* dst() { return dst_; } - - private: - TFE_TensorHandle* src_; - tensorflow::Device* dstd_; - TFE_Context* ctx_; - TFE_TensorHandle* dst_; -}; #ifdef TENSORFLOW_EAGER_USE_XLA // Synthesizes and returns a wrapper function over `op`, which must be a @@ -917,7 +837,7 @@ const tensorflow::FunctionDef* OpToFunction( } VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString(); - ctx->context.AddFunctionDef(fdef); + status->status = ctx->context.AddFunctionDef(fdef); if (!status->status.ok()) return nullptr; const auto ret = ctx->context.FindFunctionDef(signature->name()); DCHECK(ret != nullptr); @@ -965,7 +885,7 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) { // Since input param reordering may have occurred between `op` and `launch_op` // via `op_input_to_func_input`, adjust the actual inputs accordingly. launch_op->inputs = op->inputs; - for (TFE_TensorHandle* h : launch_op->inputs) { + for (tensorflow::TensorHandle* h : launch_op->inputs) { h->Ref(); } if (!op_input_to_func_input.empty()) { @@ -1140,11 +1060,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } else { // Execute checks if retvals[i] is nullptr or not to figure if it needs to // allocate it. 
+ std::vector<tensorflow::TensorHandle*> handle_retvals(*num_retvals, + nullptr); + status->status = + Execute(op->ctx, op->device, op->inputs, kernel, maybe_stats.get(), + handle_retvals.data(), *num_retvals); for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = nullptr; + retvals[i] = new TFE_TensorHandle(handle_retvals[i]); } - status->status = Execute(op->ctx, op->device, op->inputs, kernel, - maybe_stats.get(), retvals, *num_retvals); } } @@ -1152,30 +1075,12 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status) { - status->status = ctx->context.GetStatus(); - if (!status->status.ok()) { - return nullptr; - } - tensorflow::Device* dstd = ctx->context.HostCPU(); - if (device_name != nullptr && strlen(device_name) > 0) { - status->status = - ctx->context.device_mgr()->LookupDevice(device_name, &dstd); - if (!status->status.ok()) return nullptr; - } - if (ctx->context.Async()) { - // Note that `h` may not be currently ready. However execution order will - // make sure that `h` is ready before the copy is actually done. - CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx); - TFE_TensorHandle* output = node->dst(); - // Note that calling Add makes `node` accessible by the EagerExecutor - // thread. So further accesses need to be thread-safe. 
- ctx->context.ExecutorAdd(node); - return output; - } else { - TFE_TensorHandle* output = nullptr; - status->status = TensorHandleCopyToDevice(h, ctx, dstd, &output); - return output; + tensorflow::TensorHandle* handle = TFE_TensorHandleCopyToDevice_Internal( + h->handle, ctx, device_name, status); + if (status->status.ok()) { + return new TFE_TensorHandle(handle); } + return nullptr; } void TFE_ContextAddFunctionDef(TFE_Context* ctx, @@ -1214,7 +1119,7 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( tensorflow::Device* d = nullptr; tensorflow::Device* op_device = nullptr; const tensorflow::Tensor* t = nullptr; - status->status = h->TensorAndDevice(&t, &d, &op_device); + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); if (!status->status.ok()) return nullptr; if (d != nullptr) { status->status = tensorflow::errors::FailedPrecondition( @@ -1306,70 +1211,8 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, } // namespace tensorflow - -bool TFE_TensorHandle::IsReady() { - if (node_id == 0) return true; - tensorflow::mutex_lock l(ctx_mutex_); - return ctx_ == nullptr; -} - -tensorflow::Status TFE_TensorHandle::WaitReady() { - if (node_id == 0) return tensorflow::Status::OK(); - tensorflow::EagerExecutor* executor = nullptr; - { - tensorflow::mutex_lock l(ctx_mutex_); - if (ctx_ == nullptr) return tensorflow::Status::OK(); - executor = ctx_->context.Executor(); - } - return executor->WaitFor(node_id); -} - -tensorflow::Status TFE_TensorHandle::Tensor(const tensorflow::Tensor** t) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); - *t = &tensor_; - return tensorflow::Status::OK(); -} - -tensorflow::Status TFE_TensorHandle::Device(tensorflow::Device** d) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); - *d = device_; - return tensorflow::Status::OK(); -} - -tensorflow::Status TFE_TensorHandle::OpDevice(tensorflow::Device** d) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); - *d = op_device_; 
- return tensorflow::Status::OK(); -} - -tensorflow::Status TFE_TensorHandle::TensorAndDevice( - const tensorflow::Tensor** tensor, tensorflow::Device** device, - tensorflow::Device** op_device) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); - *tensor = &tensor_; - *device = device_; - *op_device = op_device_; - return tensorflow::Status::OK(); -} - -void TFE_TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor, - tensorflow::Device* device, - tensorflow::Device* op_device) { - tensorflow::mutex_lock l(ctx_mutex_); - DCHECK(node_id > 0 && ctx_) << "SetTensorAndDevice should be only called " - << "on non-ready handles."; - ctx_ = nullptr; - tensor_ = tensor; - device_ = device; - op_device_ = op_device; -} - TFE_Op::~TFE_Op() { - for (TFE_TensorHandle* h : inputs) { + for (tensorflow::TensorHandle* h : inputs) { h->Unref(); } } diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 5b29120b40..e6d2ab75ff 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/rendezvous.h" @@ -67,84 +68,18 @@ struct TFE_Context { tensorflow::EagerContext context; }; -struct TFE_TensorHandle : public tensorflow::core::RefCounted { - public: +struct TFE_TensorHandle { TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d, tensorflow::Device* op_device) - : dtype(t.dtype()), - node_id(0), - tensor_(t), - device_(d), - op_device_(op_device), - ctx_(nullptr) {} + : handle(new tensorflow::TensorHandle(t, d, op_device)) {} TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype, - TFE_Context* ctx) - : dtype(dtype), - node_id(node_id), - tensor_(dtype), - device_(nullptr), - op_device_(nullptr), - ctx_(ctx) { - DCHECK_GT(node_id, 0); - } - - ~TFE_TensorHandle() override {} - - tensorflow::Status Tensor(const tensorflow::Tensor** t); - - tensorflow::Status Device(tensorflow::Device** d); - - tensorflow::Status OpDevice(tensorflow::Device** d); - - tensorflow::Status TensorAndDevice(const tensorflow::Tensor** tensor, - tensorflow::Device** device, - tensorflow::Device** op_device); - - // Note that this can be called at most once, and only on non-ready handles, - // and makes them ready. - void SetTensorAndDevice(const tensorflow::Tensor& tensor, - tensorflow::Device* device, - tensorflow::Device* op_device); - - // dtype for the handle. It must be the same as t.dtype() once the handle is - // ready. 
- const tensorflow::DataType dtype; - - private: - // If the contents of the Tensor pointed to by this handle is yet to be - // computed by a EagerNode, this function will block till that compuatation is - // done and the handle is "ready". - tensorflow::Status WaitReady(); - - bool IsReady(); - - // Id for the EagerNode that will compute the value pointed to by this handle. - // If the value is 0, the handle is already ready, but not vice-versa. - const tensorflow::uint64 node_id; - - tensorflow::Tensor tensor_; - - // TODO(ashankar): device_ == nullptr iff local CPU - // This was expedient, but perhaps worth revisiting ('device_' should always - // be a valid pointer?) - // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are - // provided with the appropriate TFE_Context. - // - // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a - // TFE_TensorHandle does not outlive the TFE_Context from which it came? - tensorflow::Device* device_; - - // Device in which the op producing this tensor was executed. Equals to - // device_ for constant tensors. - tensorflow::Device* op_device_; - - tensorflow::mutex ctx_mutex_; - - // `ctx` is only guaranteed to be set if the handle is not "ready". This is - // typically true when the handle was produced during async execution. - // `ctx` object is not owned and should outlive this handle. 
- TFE_Context* ctx_ GUARDED_BY(ctx_mutex_); + tensorflow::EagerContext* ctx) + : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {} + + TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} + + tensorflow::TensorHandle* handle; }; struct TFE_Op { @@ -161,7 +96,7 @@ struct TFE_Op { const tensorflow::string name; tensorflow::AttrBuilder attrs; const tensorflow::AttrTypeMap* attr_types; - tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> inputs; + tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs; tensorflow::Device* device; bool use_xla = false; }; diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 02356699a2..5094e5ce67 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -74,6 +74,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index d15ccb0c28..5ce3c45528 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -177,6 +177,22 @@ StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer( return Transfer(*data, shape_with_output_layout); } +StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<GlobalData*> arguments, + const ExecutionOptions* execution_options, + ExecutionProfile* execution_profile) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr<GlobalData> data, + Execute(computation, arguments, execution_options, execution_profile)); + + const Shape* shape_with_output_layout = nullptr; + if (execution_options && 
execution_options->has_shape_with_output_layout()) { + shape_with_output_layout = &execution_options->shape_with_output_layout(); + } + return Transfer(*data, shape_with_output_layout); +} + StatusOr<Computation> Client::LoadSnapshot(const SessionModule& module) { LoadComputationSnapshotRequest request; *request.mutable_module() = module; @@ -231,6 +247,41 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute( return MakeUnique<GlobalData>(stub_, response.output()); } +StatusOr<std::unique_ptr<GlobalData>> Client::Execute( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<GlobalData*> arguments, + const ExecutionOptions* execution_options, + ExecutionProfile* execution_profile) { + ExecuteGraphRequest request; + *request.mutable_computation() = computation.proto(); + + if (execution_options == nullptr) { + *request.mutable_execution_options() = CreateDefaultExecutionOptions(); + } else { + *request.mutable_execution_options() = *execution_options; + } + for (GlobalData* argument : arguments) { + CHECK(argument != nullptr) << "Argument pointers must not be null."; + *request.add_arguments() = argument->handle(); + } + + ExecuteResponse response; + VLOG(1) << "making execute request: " << request.ShortDebugString(); + Status s = stub_->ExecuteGraph(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + if (execution_profile != nullptr) { + *execution_profile = response.profile(); + // TODO(b/74197823): Get execution stats for the graph and VLOG(1) them. 
+ } + + return MakeUnique<GlobalData>(stub_, response.output()); +} + StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel( tensorflow::gtl::ArraySlice<ComputationInstance> computations) { ExecuteParallelRequest request; diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index c28380b689..ec87646ebf 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service_interface.h" @@ -57,6 +58,21 @@ class Client { const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); + // Executes the computation with the given arguments and returns the global + // data that was produced from the execution. + // * If execution_options is not nullptr, these options are passed to the + // service to affect how it compiles our computation. (The pointer does not + // need to live beyond this call.) + // * If execution_profile is not nullptr then the pointed-to ExecutionProfile + // will be filled with profile data from the execution. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr<std::unique_ptr<GlobalData>> Execute( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<GlobalData*> arguments, + const ExecutionOptions* execution_options = nullptr, + ExecutionProfile* execution_profile = nullptr); + // A struct to represent a computation instance to be executed. 
// * If execution_options.device_handles is not empty, the computation is // executed on the devices associated with the handles by partitioning the @@ -137,6 +153,17 @@ class Client { const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); + // Executes the computation with the given arguments and transfers the result + // to the client as a literal. Parameters are defined the same as for + // Execute() and Transfer(). + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<GlobalData*> arguments, + const ExecutionOptions* execution_options = nullptr, + ExecutionProfile* execution_profile = nullptr); + // Unregister the memory for the given GlobalData on the device. Status Unregister(const GlobalData& data); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 91396f055f..30594243dc 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -265,6 +265,24 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile( updated_options)); } +StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts, + const ExecutableBuildOptions& options) { + ExecutableBuildOptions updated_options = options; + if (options.device_ordinal() == -1) { + updated_options.set_device_ordinal(default_device_ordinal()); + VLOG(3) << "Set device ordinal to default value of: " + << updated_options.device_ordinal(); + } + TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable, + local_service_->CompileExecutable( + computation, argument_layouts, updated_options)); + return WrapUnique(new LocalExecutable(std::move(executable), + local_service_->mutable_backend(), + 
updated_options)); +} + StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator) { diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index de0ed13c43..98ee7c62c9 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -123,6 +123,15 @@ class LocalClient : public Client { const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts, const ExecutableBuildOptions& options); + // Build and return a LocalExecutable object. The executable is compiled using + // the given XlaComputation, argument layouts and options. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr<std::unique_ptr<LocalExecutable>> Compile( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts, + const ExecutableBuildOptions& options); + // Copy the literal data to the device with the given ordinal and return as a // ScopedShapedBuffer. If non-null the given memory allocator is used for // device memory allocation. If null, the default memory allocator for the diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index b912889e26..cc5f551c9c 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -25,12 +25,25 @@ filegroup( load("//tensorflow:tensorflow.bzl", "tf_cc_test") +cc_library( + name = "xla_computation", + srcs = ["xla_computation.cc"], + hdrs = ["xla_computation.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/core:lib", + ], +) + # TODO(b/74197823): Replace computation_builder with xla_builder. 
cc_library( name = "xla_builder", srcs = ["xla_builder.cc"], hdrs = ["xla_builder.h"], deps = [ + ":xla_computation", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -38,6 +51,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shape_inference", diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 82b61d4d51..596f39b4fd 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include <numeric> #include <string> #include <utility> @@ -80,40 +81,32 @@ void XlaBuilder::NoteError(const Status& error) { } } -StatusOr<XlaComputation> XlaBuilder::Build() { - if (!first_error_.ok()) { - string backtrace; - first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); - return AppendStatus(first_error_, backtrace); - } - - HloComputationProto entry; - ProgramShape* program_shape = entry.mutable_program_shape(); - - entry.set_name(name_); +StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) { + TF_RET_CHECK(root_id != nullptr); + ProgramShape program_shape; // Not all instructions can be roots. Walk backwards from the last added // instruction until a valid root is found. 
- entry.set_root_id(-1); - for (int64 i = instructions_.size() - 1; i >= 0; i--) { + int64 index = instructions_.size() - 1; + for (; index >= 0; index--) { TF_ASSIGN_OR_RETURN(HloOpcode opcode, - StringToHloOpcode(instructions_[i].opcode())); + StringToHloOpcode(instructions_[index].opcode())); if (CanBeRoot(opcode)) { - entry.set_root_id(instructions_[i].id()); - *program_shape->mutable_result() = instructions_[i].shape(); break; } } - if (entry.root_id() == -1) { + if (index < 0) { return FailedPrecondition("no root instruction was found"); } + *root_id = instructions_[index].id(); + *program_shape.mutable_result() = instructions_[index].shape(); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. const int64 param_count = parameter_numbers_.size(); for (int64 i = 0; i < param_count; i++) { - program_shape->add_parameters(); - program_shape->add_parameter_names(); + program_shape.add_parameters(); + program_shape.add_parameter_names(); } for (const HloInstructionProto& instr : instructions_) { // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). 
So @@ -123,10 +116,35 @@ StatusOr<XlaComputation> XlaBuilder::Build() { const int64 index = instr.parameter_number(); TF_RET_CHECK(index >= 0 && index < param_count) << "invalid parameter number: " << index; - *program_shape->mutable_parameters(index) = instr.shape(); - *program_shape->mutable_parameter_names(index) = instr.name(); + *program_shape.mutable_parameters(index) = instr.shape(); + *program_shape.mutable_parameter_names(index) = instr.name(); } } + return program_shape; +} + +StatusOr<ProgramShape> XlaBuilder::GetProgramShape() { + int64 root_id; + return GetProgramShape(&root_id); +} + +StatusOr<XlaComputation> XlaBuilder::Build() { + if (!first_error_.ok()) { + string backtrace; + first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); + return AppendStatus(first_error_, backtrace); + } + + HloComputationProto entry; + entry.set_name(name_); + + { + int64 root_id; + ProgramShape program_shape; + TF_ASSIGN_OR_RETURN(program_shape, GetProgramShape(&root_id)); + entry.mutable_program_shape()->Swap(&program_shape); + entry.set_root_id(root_id); + } for (auto& instruction : instructions_) { entry.add_instructions()->Swap(&instruction); @@ -149,31 +167,134 @@ StatusOr<XlaComputation> XlaBuilder::Build() { return std::move(computation); } -XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { - auto op = [&]() -> StatusOr<XlaOp> { +StatusOr<XlaOp> XlaBuilder::InDimBroadcast( + const Shape& shape, const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + for (int64 dim : broadcast_dimensions) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand}); +} + +StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape, + const XlaOp& operand) { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + + 
CHECK(ShapeUtil::IsScalar(operand_shape) || + ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); + Shape broadcast_shape = + ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type()); + + // Do explicit broadcast for scalar. + if (ShapeUtil::IsScalar(operand_shape)) { + return InDimBroadcast(broadcast_shape, operand, {}); + } + + // Do explicit broadcast for degenerate broadcast. + std::vector<int64> broadcast_dimensions; + std::vector<int64> reshaped_dimensions; + for (int i = 0; i < ShapeUtil::Rank(operand_shape); i++) { + if (operand_shape.dimensions(i) == output_shape.dimensions(i)) { + broadcast_dimensions.push_back(i); + reshaped_dimensions.push_back(operand_shape.dimensions(i)); + } else { + TF_RET_CHECK(operand_shape.dimensions(i) == 1) + << "An explicit broadcast sequence requires the broadcasted " + "dimensions to be trivial; operand shape: " + << operand_shape << "; output_shape: " << output_shape; + } + } + // Eliminate the size one dimensions. + TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, + Reshape(ShapeUtil::MakeShape(operand_shape.element_type(), + reshaped_dimensions), + operand)); + // Broadcast 'reshape' up to the larger size. 
+ return InDimBroadcast(broadcast_shape, reshaped_operand, + broadcast_dimensions); +} + +XlaOp XlaBuilder::BinaryOp( + HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, lhs_shape, - rhs_shape, broadcast_dimensions)); - return AddInstruction(std::move(instr), HloOpcode::kAdd, {lhs, rhs}); - }; - return NoteErrorOrReturn(op()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferBinaryOpShape( + binop, lhs_shape, rhs_shape, broadcast_dimensions)); + + const int64 lhs_rank = ShapeUtil::Rank(lhs_shape); + const int64 rhs_rank = ShapeUtil::Rank(rhs_shape); + + XlaOp updated_lhs = lhs; + XlaOp updated_rhs = rhs; + + if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) { + const bool should_broadcast_lhs = lhs_rank < rhs_rank; + XlaOp from = should_broadcast_lhs ? lhs : rhs; + const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape; + + std::vector<int64> to_size; + for (int64 size : instr.shape().dimensions()) { + to_size.push_back(size); + } + for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape); + from_dim++) { + int64 to_dim = broadcast_dimensions[from_dim]; + to_size[to_dim] = from_shape.dimensions(from_dim); + } + + const Shape& broadcasted_shape = + ShapeUtil::MakeShape(from_shape.element_type(), to_size); + TF_ASSIGN_OR_RETURN( + XlaOp broadcasted_operand, + InDimBroadcast(broadcasted_shape, from, broadcast_dimensions)); + + updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs; + updated_rhs = !should_broadcast_lhs ? 
broadcasted_operand : rhs; + } + + TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, updated_lhs.GetShape()); + if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_lhs, + AddBroadcastSequence(instr.shape(), updated_lhs)); + } + TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, updated_rhs.GetShape()); + if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_rhs, + AddBroadcastSequence(instr.shape(), updated_rhs)); + } + + return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs}); + }()); +} + +XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) { - HloInstructionProto instr; - *instr.mutable_shape() = literal.shape(); - *instr.mutable_literal() = literal.ToProto(); - return AddInstruction(std::move(instr), HloOpcode::kConstant); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + *instr.mutable_shape() = literal.shape(); + *instr.mutable_literal() = literal.ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kConstant); + }()); } XlaOp XlaBuilder::Call(const XlaComputation& computation, tensorflow::gtl::ArraySlice<XlaOp> operands) { - auto op = [&]() -> StatusOr<XlaOp> { + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; std::vector<const Shape*> operand_shape_ptrs; std::vector<Shape> operand_shapes; @@ -196,13 +317,12 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, } return AddInstruction(std::move(instr), HloOpcode::kCall, operands); - }; - return NoteErrorOrReturn(op()); + }()); } XlaOp 
XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, const string& name) { - auto op = [&]() -> StatusOr<XlaOp> { + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) { return InvalidArgument("parameter %lld already registered", @@ -213,12 +333,496 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, instr.set_name(name); *instr.mutable_shape() = shape; return AddInstruction(std::move(instr), HloOpcode::kParameter); - }; - return NoteErrorOrReturn(op()); + }()); +} + +XlaOp XlaBuilder::Broadcast( + const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN( + const Shape& shape, + ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes)); + + // The client-level broadcast op just appends dimensions on the left (adds + // lowest numbered dimensions). The HLO broadcast instruction is more + // flexible and can add new dimensions anywhere. The instruction's + // dimensions field maps operand dimensions to dimensions in the broadcast + // output, so to append dimensions on the left the instruction's dimensions + // should just be the n highest dimension numbers of the output shape where + // n is the number of input dimensions. 
+ const int64 operand_rank = ShapeUtil::Rank(operand_shape); + std::vector<int64> dimensions(operand_rank); + for (int i = 0; i < operand_rank; ++i) { + dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank; + } + return InDimBroadcast(shape, operand, dimensions); + }()); +} + +StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand}); +} + +XlaOp XlaBuilder::Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> start_indices, + tensorflow::gtl::ArraySlice<int64> limit_indices, + tensorflow::gtl::ArraySlice<int64> strides) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice<int64> slice_sizes) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands, + int64 dimension) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<int64> new_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& shape, + ShapeInference::InferReshapeShape( + operand_shape, dimensions, new_sizes)); + XlaOp transposed = IsIdentityPermutation(dimensions) + ? 
operand + : Transpose(operand, dimensions); + return Reshape(shape, transposed); + }()); +} + +XlaOp XlaBuilder::Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> new_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(auto shape, operand.GetShape()); + std::vector<int64> dimensions(shape.dimensions_size()); + std::iota(dimensions.begin(), dimensions.end(), 0); + return Reshape(operand, dimensions, new_sizes); + }()); +} + +XlaOp XlaBuilder::Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions) { + return UnimplementedOp(); +} + +void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { + UnimplementedOp(); +} + +XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, + const XlaOp& on_false) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Dot(const XlaOp& lhs, 
const XlaOp& rhs) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + Padding padding) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + tensorflow::gtl::ArraySlice<int64> lhs_dilation, + tensorflow::gtl::ArraySlice<int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, + const tensorflow::gtl::ArraySlice<int64> fft_length) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { + return UnimplementedOp(); +} + +void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config) { + UnimplementedOp(); +} + +XlaOp XlaBuilder::CustomCall(const string& call_target_name, + 
tensorflow::gtl::ArraySlice<XlaOp> operands, + const Shape& shape) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands, + const string& channel_name, + int64 cost_estimate_ns, const Shape& shape) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Complex( + const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Conj(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Not(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::ShiftLeft( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp 
XlaBuilder::ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Abs(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Atan2( + const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Exp(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Floor(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Ceil(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Round(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Log(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Sign(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Cos(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Sin(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Tanh(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Real(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Imag(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> permutation) { + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferTransposeShape(operand_shape, permutation)); + for (int64 dim : permutation) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), 
HloOpcode::kTranspose, {operand}); + }()); +} + +XlaOp XlaBuilder::Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Sort(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::SquareF32(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Neg(const XlaOp& operand) { return UnimplementedOp(); } + +XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand, + const XlaOp& max) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<XlaOp> static_operands) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::RngNormal(const XlaOp& mu, const XlaOp& sigma, + const Shape& shape) { + return UnimplementedOp(); } -XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, - tensorflow::gtl::ArraySlice<XlaOp> operands) { +XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b, + const Shape& shape) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::While(const XlaComputation& condition, + const XlaComputation& body, const XlaOp& init) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + 
tensorflow::gtl::ArraySlice<int64> window_bounds) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::Reduce( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ReduceWindow( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) { + return 
UnimplementedOp(); +} + +XlaOp XlaBuilder::SelectAndScatter( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return UnimplementedOp(); +} + +XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits) { + return UnimplementedOp(); +} + +void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { + UnimplementedOp(); +} + +XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { + return UnimplementedOp(); +} + +StatusOr<XlaOp> XlaBuilder::AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + tensorflow::gtl::ArraySlice<XlaOp> operands) { const int64 handle = instructions_.size(); instr.set_id(handle); instr.set_opcode(HloOpcodeString(opcode)); @@ -229,7 +833,12 @@ XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, instr.set_name(StrCat(instr.name(), ".", handle)); } for (const auto& operand : operands) { + TF_RET_CHECK(operand.builder_ != nullptr); + TF_RET_CHECK(operand.builder_ == this) + << "Do not add XlaOp from builder " << operand.builder_->name() + << " to builder " << this->name(); instr.add_operand_ids(operand.handle()); + // TODO(b/74197823): Set metadata and sharding. 
} instructions_.push_back(instr); @@ -246,4 +855,9 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction( return &instructions_[op.handle()]; } +XlaOp XlaBuilder::UnimplementedOp() { + NoteError(Unimplemented("Op not yet implemented")); + return {}; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index f1d10ecdb9..c19eb47165 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -24,6 +24,8 @@ limitations under the License. #include <string> #include <utility> +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -50,10 +52,11 @@ class XlaBuilder; // TODO(b/74197823): Replace xla::ComputationDataHandle with this one. class XlaOp { public: + XlaOp() : handle_(0), builder_(nullptr) {} + StatusOr<Shape> GetShape() const; private: - XlaOp() : handle_(0), builder_(nullptr) {} XlaOp(int64 handle, XlaBuilder* builder) : handle_(handle), builder_(builder) {} @@ -64,38 +67,6 @@ class XlaOp { XlaBuilder* builder_; // Not owned. }; -// The computation graph that the user builds up with the XlaBuilder. -// -// TODO(b/74197823): Replace xla::Computation with this one. -class XlaComputation { - public: - XlaComputation(const XlaComputation&) = delete; - XlaComputation& operator=(const XlaComputation&) = delete; - - XlaComputation(XlaComputation&& from) { *this = std::move(from); } - - XlaComputation& operator=(XlaComputation&& from) { - proto_ = std::move(from.proto()); - unique_id_ = from.unique_id_; - return *this; - } - - // Returns the "program shape" (parameter and return shapes) for this - // computation. 
- const ProgramShape& GetProgramShape() const { return proto_.program_shape(); } - - const HloModuleProto& proto() const { return proto_; } - - private: - // Creates a null Computation. - XlaComputation(const int64 unique_id) : unique_id_(unique_id) {} - HloModuleProto* mutable_proto() { return &proto_; } - friend class XlaBuilder; - - int64 unique_id_; - HloModuleProto proto_; -}; - // A convenient interface for building up computations. // // Thread-compatible. @@ -121,14 +92,6 @@ class XlaBuilder { die_immediately_on_error_ = enabled; } - // Enqueues an add instruction onto the computation. - XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); - - // Enqueues a call instruction onto the computation. - XlaOp Call(const XlaComputation& computation, - tensorflow::gtl::ArraySlice<XlaOp> operands); - // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. XlaOp Parameter(int64 parameter_number, const Shape& shape, @@ -156,17 +119,597 @@ class XlaBuilder { // corresponding native type yet. 
template <typename NativeT> XlaOp ConstantR0(NativeT value); + template <typename NativeT> + XlaOp ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values); + XlaOp ConstantR1(const tensorflow::core::Bitmap& values); + template <typename NativeT> + XlaOp ConstantR2( + std::initializer_list<std::initializer_list<NativeT>> values); + template <typename NativeT> + XlaOp ConstantFromArrayWithLayout(const Array<NativeT>& values, + const Layout& layout); + template <typename NativeT> + XlaOp ConstantFromArray(const Array<NativeT>& values); + template <typename NativeT> + XlaOp ConstantR2FromArray2DWithLayout(const Array2D<NativeT>& values, + const Layout& layout); + template <typename NativeT> + XlaOp ConstantR2FromArray2D(const Array2D<NativeT>& values); + template <typename NativeT> + XlaOp ConstantR3FromArray3DWithLayout(const Array3D<NativeT>& values, + const Layout& layout); + template <typename NativeT> + XlaOp ConstantR3FromArray3D(const Array3D<NativeT>& values); + template <typename NativeT> + XlaOp ConstantR4FromArray4DWithLayout(const Array4D<NativeT>& values, + const Layout& layout); + template <typename NativeT> + XlaOp ConstantR4FromArray4D(const Array4D<NativeT>& values); - // Returns the shape of the given op. - StatusOr<Shape> GetShape(const XlaOp& op) const; + // Enqueues a rank one constant (vector) onto the computation. The vector has + // size 'length' and every element has the value 'value'. + template <typename NativeT> + XlaOp ConstantR1(int64 length, NativeT value); + + // Adds dimensions to an array by duplicating the data in the array. + // + // The new dimensions are inserted on the left, i.e. if + // broadcast_sizes has values {a0, ..., aN} and the operand shape + // has dimensions {b0, ..., bM} then the shape of the output has + // dimensions {a0, ..., aN, b0, ..., bM}. + // + // The new dimensions index into copies of the operand, i.e. 
+ // + // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] + XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> broadcast_sizes); + + // Enqueues a pad operation onto the computation that pads the given value on + // the edges as well as between the elements of the input. padding_config + // specifies the padding amount for each dimension. + XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + + // Enqueues an operation onto the computation that flattens the operand based + // on the dimension order (major/slowest-varying to minor/fastest-varying) + // given, followed by reshaping it into the shape with the given dimension + // sizes (also major to minor). Conceptually, this is a limited form of + // "shape casting". + XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<int64> new_sizes); + + // Enqueues an operation onto the computation that collapses the operand, from + // first to last dimension (C order), then reshapes it to the given dimension + // sizes. Conceptually, this is a limited form of "shape casting". + XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> new_sizes); + + // Wrapper for Reshape. + // Enqueues an operation to collapse the provided dimensions; e.g. an + // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to + // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must + // be a consecutive, in-order subsequence of the operand dimensions. 
+ // + // Note that collapsing a single dimension does nothing: + // + // {256} collapsing {0} => {256} + // {1} collapsing {0} => {1} + // + // Collapsing multiple dimensions produces a single result dimension: + // + // {256, 2} collapsing {0,1} => {512} + // {256, 2, 3} collapsing {0,1} => {512, 3} + // + // This could potentially cause data to be moved -- it provides a more + // structured form of reshaping than an arbitrary Reshape operation. + XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions); + + // Enqueues a slice operation onto the computation that slices the operand + // from the start indices to the limit indices; e.g. + // + // x + // [ 0 1 2 3 ] + // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] + // [ 8 9 a b ] + // + // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D + // range notation. + // The strides parameter determines the stride over the slice + XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> start_indices, + tensorflow::gtl::ArraySlice<int64> limit_indices, + tensorflow::gtl::ArraySlice<int64> strides); + + // Enqueues a slice operation in a given dimension, taking all other + // dimensions as they are; e.g. if dimno is 1 from start_index 2 to + // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand + // for: + // + // array[:, 2:4:1, :] + XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno); + + // Enqueues a slice operation onto the computation that slices the 'operand' + // from dynamic start indices which are passed in 'start_indices'. + // The size of the slice in each dimension is passed in 'slice_sizes', + // which specify the end point of exclusive slice intervals in each + // dimension [start, start + size). + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. 
+ // Slice index calculations are computed modulo input dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice<int64> slice_sizes); + + // Enqueues a dynamic update slice operation onto the computation, which + // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. + // The shape of 'update' determines the shape of the slice of 'operand' + // which is updated. + // The indices specified in 'start_indices' specify the offset of the slice + // of 'operand' which is updated. + // + // update = {10, 11} // calculated at runtime. + // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] + // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] + // [7 8 9] [7 8 9 ] + // + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo update dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + + // Enqueues a concatenate instruction onto the computation. 'operands' must + // have >= 1 entry. + XlaOp ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands, + int64 dimension); + + // Enqueue a tracing operation onto the computation; the computation will emit + // a logging message with the operand. + void Trace(const string& tag, const XlaOp& operand); + + // Enqueues a conditional-move-like select operation onto the computation; + // predicated on pred, selects between on_true and on_false. + XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); + + // Enqueues a tuple-creation instruction onto the computation. 
+ XlaOp Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements); + + // Enqueues a tuple-element-get instruction onto the computation. + XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + + // Enqueues an equal-to comparison instruction onto the computation. + XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a not-equal comparison instruction onto the computation. + XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a greater-or-equal comparison instruction onto the computation. + XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a greater-than comparison instruction onto the computation. + XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a less-than comparison instruction onto the computation. + XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a less-or-equal comparison instruction onto the computation. + XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a dot instruction onto the computation. + XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + + // Enqueues a general dot instruction onto the computation. + XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, which uses the + // default convolution dimension numbers. + XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + Padding padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration in the format returned by MakePadding(). 
+ XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided dimension numbers configuration. + XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration as well as the dimension numbers. + XlaOp ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration, dilation factors and dimension numbers. + XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + tensorflow::gtl::ArraySlice<int64> lhs_dilation, + tensorflow::gtl::ArraySlice<int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues an FFT instruction onto the computation, of the given type and + // with the given FFT length. + XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice<int64> fft_length); + + // Enqueues an infeed instruction onto the computation, which writes data of + // the given shape to the infeed buffer of the device. + XlaOp Infeed(const Shape& shape, const string& config = ""); + + // Enqueues an outfeed instruction onto the computation. This instruction + // generates outgoing data transfers for the given data. 
+ // + // shape_with_layout communicates the laid out shape that we want to outfeed + // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error + // will occur. + void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + + // Enqueues a call instruction onto the computation. + XlaOp Call(const XlaComputation& computation, + tensorflow::gtl::ArraySlice<XlaOp> operands); + + // Enqueues a custom call instruction onto the computation. + // During code generation, a call instruction is emitted which targets a + // symbol with the name |call_target_name|. The |operands| are passed to the + // call instruction. |shape| is the resultant shape. + XlaOp CustomCall(const string& call_target_name, + tensorflow::gtl::ArraySlice<XlaOp> operands, + const Shape& shape); + + // Enqueues a pseudo-op to represent host-side computation data-dependencies. + // During code generation, host send and receive operations will be generated + // to transfer |operands| to the host and a single result of |shape| back to + // the device. Host send/recv operations are emitted using |channel_name|. + // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO + // instruction scheduling. + XlaOp HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + + // The following methods enqueue element-wise binary arithmetic operations + // onto the computation. The shapes of the operands have to match unless one + // of the operands is a scalar, or an explicit broadcast dimension is given + // (see g3doc for more details). + + // Enqueues a complex compose instruction onto the computation. + XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a complex conjugate instruction onto the computation. 
+ XlaOp Conj(const XlaOp& operand); + + // Enqueues an add instruction onto the computation. + XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a subtract instruction onto the computation. + XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a multiply instruction onto the computation. + XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a divide instruction onto the computation. + XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a remainder instruction onto the computation. + XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a max instruction onto the computation. + XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues a min instruction onto the computation. 
+ XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Element-wise logical operators + XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + XlaOp Not(const XlaOp& operand); + + XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Reduces an array among the provided dimensions, given "computation" as a + // reduction operator. + XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); + + // Convenience wrapper around the above that reduces all the dimensions in the + // operand shape. + XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + + // Enqueues a windowed reduce instruction onto the computation. + XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, + Padding padding); + + // As ReduceWindow(), but the padding is given in the format + // returned by MakePadding(). 
+ XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding); + + // Returns the sum of the operand value across all replicas. All replicas + // supply one input to the sum and all replicas receive the resulting sum. + XlaOp CrossReplicaSum(const XlaOp& operand); + + // Enqueues an operation that scatters the `source` array to the selected + // indices of each window. + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter); + + // As SelectAndScatter(), but the padding is given in the format + // returned by MakePadding(). + XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice<int64> window_dimensions, + tensorflow::gtl::ArraySlice<int64> window_strides, + tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + + // Enqueues an abs instruction onto the computation. + XlaOp Abs(const XlaOp& operand); + + // Enqueues a atan2 instruction onto the computation. + XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues an exp instruction onto the computation. + XlaOp Exp(const XlaOp& operand); + + // Enqueues a floor instruction onto the computation. + XlaOp Floor(const XlaOp& operand); + + // Enqueues a ceil instruction onto the computation. 
+ XlaOp Ceil(const XlaOp& operand); + + // Enqueues a round instruction onto the computation, rounding to nearest even + // with half-way cases rounding away from zero. + XlaOp Round(const XlaOp& operand); + + // Enqueues an log instruction (natural logarithm) onto the computation. + XlaOp Log(const XlaOp& operand); + + // Enqueues a sign instruction onto the computation. + XlaOp Sign(const XlaOp& operand); + + // Enqueues a cosine instruction onto the computation. + XlaOp Cos(const XlaOp& operand); + + // Enqueues a sine instruction onto the computation. + XlaOp Sin(const XlaOp& operand); + + // Enqueues a tanh instruction onto the computation. + XlaOp Tanh(const XlaOp& operand); + + // Enqueues a real-part instruction onto the computation. + XlaOp Real(const XlaOp& operand); + + // Enqueues an imaginary-part instruction onto the computation. + XlaOp Imag(const XlaOp& operand); + + // Enqueues a float32 sqrt instruction onto the computation. + // (float32 is specified as there is an implicit float32 0.5f constant + // exponent). + XlaOp SqrtF32(const XlaOp& operand); + + // Enqueues a float32 square instruction onto the computation. + // (float32 is specified as there is an implicit float32 2.0f constant + // exponent). + XlaOp SquareF32(const XlaOp& operand); + + // Enqueues a lhs^rhs computation onto the computation. + XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + + // Enqueues an operator that tests if the operand's values are finite, i.e., + // not Inf or NaN. Defined only for floating-point types. Returns an array of + // booleans with the same shape where entries are true iff the corresponding + // entry was NaN. + XlaOp IsFinite(const XlaOp& operand); + + // Enqueues a convert instruction onto the computation that changes the + // element type of the operand array to primitive_type. 
+ XlaOp ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type); + + // Enqueues a no-op instruction onto the computation that changes + // the element type of the operand array to primitive_type. The + // bit-widths of the source and destination element types must be + // identical. + XlaOp BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type); + + // Enqueues a float32 reciprocal instruction onto the computation. + // (float32 is specified as there is an implicit float32 -1.0f constant + // exponent). + // + // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the + // shape of the operand. + XlaOp ReciprocalF32(const XlaOp& operand); + + // Enqueues a negate instruction onto the computation. + XlaOp Neg(const XlaOp& operand); + + // Enqueues a transpose instruction onto the computation. + XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> permutation); + + // Enqueues a reverse instruction onto the computation. The order of the + // elements in the given dimensions is reversed (i.e., the element at index i + // is moved to index dimension_size - 1 - i). + XlaOp Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions); + + // Enqueues a sort (as increasing order) instruction onto the computation. + XlaOp Sort(const XlaOp& operand); + + // Enqueues a clamp instruction onto the computation. + XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + + // Enqueues a map instruction onto the computation. + XlaOp Map(tensorflow::gtl::ArraySlice<XlaOp> operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<XlaOp> static_operands = {}); + + // Enqueues a N(mu, sigma) random number generation instruction onto the + // computation. 
+ XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); + + // Enqueues a U(a, b) random number generation instruction onto the + // computation. Returns values in the semi-open interval [a, b). + XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + + // Enqueues a while node onto the computation. + XlaOp While(const XlaComputation& condition, const XlaComputation& body, + const XlaOp& init); + + // Enqueues a conditional node onto the computation. + XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + + // Enqueues a ReducePrecision node onto the computation. + XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + + // Enqueues a Gather node onto the computation. + XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds); + + // Enqueues a Send node onto the computation, to send the given operand to + // a Recv instruction that shares the same channel handle. + void Send(const XlaOp& operand, const ChannelHandle& handle); + + // Enqueues a Recv node onto the computation. The data comes from a Send + // instruction that shares the same channel handle and its shape must + // be the same as the given shape. + XlaOp Recv(const Shape& shape, const ChannelHandle& handle); + + // Returns true if 'operand' is a compile-time constant. A compile-time + // constant does not depend on parameters with index greater than or equal to + // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. + // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a + // compile-time constant without evaluating the computation. 
+ StatusOr<bool> IsConstant(const XlaOp& operand, int64 num_parameters = 0); + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` + // is the normalized result and batch_mean and batch_var are the mean and + // variance, respectively, across batch for the operand. + XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // `BatchNormInference` is equivalent to calling `BatchNormTraining` without + // computing `mean` and `variance` for each batch inside the operation. It + // uses the input `mean` and `variance` instead as estimated values. The + // purpose of this op is to reduce latency in inference, hence the name + // `BatchNormInference`. + // + // The output has the same shape as `operand`, and contains the normalized + // values for each batch. + XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + + // Calculates the gradients of a batch norm op. + // + // The inputs `batch_mean` and `batch_var` represent the mean and variance + // across the batch. + // + // Returns a tuple of three elements: + // - grad_operand: Gradient with respect to input `operand` + // - grad_offset: Gradient with respect to input `offset` + // - grad_scale: Gradient with respect to input `scale` + XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); // Builds the computation with the requested operations, or returns a non-ok // status. StatusOr<XlaComputation> Build(); + // Returns the first error that was encountered while building the + // computation. 
When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + + // Returns the shape of the given op. + StatusOr<Shape> GetShape(const XlaOp& op) const; + + // Returns the (inferred) result for the current computation's shape. + StatusOr<ProgramShape> GetProgramShape(); + private: - XlaOp AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, - tensorflow::gtl::ArraySlice<XlaOp> operands = {}); + StatusOr<XlaOp> AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + tensorflow::gtl::ArraySlice<XlaOp> operands = {}); // Notes that the error occurred by: // * storing it internally and capturing a backtrace if it's the first error @@ -182,8 +725,34 @@ class XlaBuilder { return op.ConsumeValueOrDie(); } + // Helper method that creates an empty op and notes error. + XlaOp UnimplementedOp(); + StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const; + // Internal helper method that does the building for an arbitrary binary op. + // broadcast_dimensions specifies which dimensions to use for broadcasting + // when the operation is between tensors of different ranks. + XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + + StatusOr<XlaOp> InDimBroadcast( + const Shape& shape, const XlaOp& operand, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions); + + // Internal helper method that creates a sequence of instructions that + // performs an explicit broadcast of the operand to the target shape. + StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape, + const XlaOp& operand); + + // Internal helper method for creating a Reshape op with the already inferred + // shape. 
+ StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand); + + // Returns the (inferred) result for the program shape for the current + // computation and fills the root_id in the pointer. + StatusOr<ProgramShape> GetProgramShape(int64* root_id); + string name_; // Name to use for the built computation. // The first error encountered while building the computation. @@ -213,6 +782,76 @@ XlaOp XlaBuilder::ConstantR0(NativeT value) { return ConstantLiteral(*Literal::CreateR0<NativeT>(value)); } +template <typename NativeT> +XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values) { + return ConstantLiteral(*Literal::CreateR1<NativeT>(values)); +} + +template <typename NativeT> +XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType<NativeT>(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(literal); +} + +inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { + return ConstantLiteral(*Literal::CreateR1(values)); +} + +template <typename NativeT> +XlaOp XlaBuilder::ConstantR2( + std::initializer_list<std::initializer_list<NativeT>> values) { + return ConstantLiteral(*Literal::CreateR2<NativeT>(values)); +} + +template <typename NativeT> +XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values, + const Layout& layout) { + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout<NativeT>(values, layout)); +} + +template <typename NativeT> +XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) { + return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values)); +} + +template <typename NativeT> +XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( + const Array2D<NativeT>& values, const Layout& layout) { + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout<NativeT>(values, layout)); +} + +template <typename NativeT> +XlaOp XlaBuilder::ConstantR2FromArray2D(const 
Array2D<NativeT>& values) { + return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values)); +} + +template <typename NativeT> +XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( + const Array3D<NativeT>& values, const Layout& layout) { + return ConstantLiteral( + *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout)); +} + +template <typename NativeT> +XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D<NativeT>& values) { + return ConstantFromArray(values); +} + +template <typename NativeT> +XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout( + const Array4D<NativeT>& values, const Layout& layout) { + return ConstantFromArrayWithLayout(values, layout); +} + +template <typename NativeT> +XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) { + return ConstantFromArray(values); +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc index a400e4e78b..529287a57a 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -57,16 +57,16 @@ TEST_F(XlaBuilderTest, OnePlusTwo) { EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); } -TEST_F(XlaBuilderTest, ParamPlusConstant) { +TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { XlaBuilder b(TestName()); auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); b.Add(x, b.ConstantR0<float>(1.0)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Parameter(), op::Constant())); + EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant()))); } -TEST_F(XlaBuilderTest, ParamPlusParam) { +TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { XlaBuilder b(TestName()); const auto& x_shape = 
ShapeUtil::MakeShape(S32, {2, 4, 6}); const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4}); @@ -79,7 +79,7 @@ TEST_F(XlaBuilderTest, ParamPlusParam) { TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(1))); + EXPECT_THAT(root, op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1)))); } TEST_F(XlaBuilderTest, XPlusX) { @@ -133,5 +133,89 @@ TEST_F(XlaBuilderTest, Call) { op::Call(op::Constant(), op::Constant()))); } +TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); + b.Add(x, y); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + // Expected: + // + // x: f32[1,2,3] y: f32[1,2,1] + // | | + // | reshape: f32[1,2] + // | | + // | broadcast: f32[1,2,3] + // \ / + // add + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), + op::Broadcast(op::Reshape(op::Parameter(1))))); +} + +TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); + b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + // The binary operation has in-dim broadcast and degenerate broadcast, should + // first do the in-dim broadcast then convert the degnerate broadcast into a + // reshape and a broadcast. 
+ // + // Expected: + // + // x: f32[2,3] y: f32[2,1,4] + // | | + // broadcast: f32[2,3,4] reshape: f32[2,4] + // | | + // | broadcast: f32[2,3,4] + // \ / + // add + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Broadcast(op::Parameter(0)), + op::Broadcast(op::Reshape(op::Parameter(1))))); +} + +TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { + XlaBuilder b1("b1"); + auto p0 = b1.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + XlaBuilder builder("main"); + builder.Add(p0, p0); + auto statusor = builder.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Do not add XlaOp from builder b1 to builder main")); +} + +TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + b.Reshape(x, /*new_sizes=*/{6, 35}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Parameter())); +} + +TEST_F(XlaBuilderTest, ReshapeHasTranspose) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + b.Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter()))); +} + +TEST_F(XlaBuilderTest, Transpose) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + b.Transpose(x, /*permutation=*/{1, 0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Transpose(op::Parameter())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc 
b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc new file mode 100644 index 0000000000..3681792eee --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc @@ -0,0 +1,26 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" + +#include <utility> + +namespace xla { + +const ProgramShape& XlaComputation::GetProgramShape() const { + return proto_.program_shape(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h new file mode 100644 index 0000000000..5b89747fdd --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ + +#include <utility> + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// The computation graph that the user builds up with the XlaBuilder. +// +// TODO(b/74197823): Replace xla::Computation with this one. +class XlaComputation { + public: + XlaComputation(const XlaComputation&) = delete; + XlaComputation& operator=(const XlaComputation&) = delete; + + XlaComputation(XlaComputation&& from) = default; + + XlaComputation& operator=(XlaComputation&& from) = default; + + // Returns the "program shape" (parameter and return shapes) for this + // computation. 
+ const ProgramShape& GetProgramShape() const; + const HloModuleProto& proto() const { return proto_; } + + private: + XlaComputation(const int64 unique_id) : unique_id_(unique_id) {} + HloModuleProto* mutable_proto() { return &proto_; } + friend class XlaBuilder; + + int64 unique_id_; + HloModuleProto proto_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 214c2030cd..13675b7d00 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -1385,8 +1385,9 @@ void Literal::EachCellAsString( } namespace { -template <typename NativeSrcT, typename NativeDestT> -std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) { +template <typename NativeSrcT, typename NativeDestT, typename ConverterType> +std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter( + const Literal& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType( src_literal.shape(), @@ -1396,11 +1397,18 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) { int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = static_cast<NativeDestT>(src_data[i]); + dest_data[i] = converter(src_data[i]); } return result_literal; } +template <typename NativeSrcT, typename NativeDestT> +std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) { + auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); }; + return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>( + src_literal, converter); +} + template <PrimitiveType primitive_src_type> std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); @@ 
-1492,8 +1500,16 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert( } StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape( - const Shape& dest_shape) const { + const Shape& dest_shape, bool round_f32_to_bf16) const { if (!ShapeUtil::IsTuple(dest_shape)) { + if (round_f32_to_bf16 && shape().element_type() == F32 && + dest_shape.element_type() == BF16) { + auto converter = [](float src) { + return tensorflow::bfloat16::round_to_bfloat16(src); + }; + return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this, + converter); + } return Convert(dest_shape.element_type()); } std::vector<Literal> elements; diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index e24f5285d9..a96a76fbb4 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -340,8 +340,14 @@ class Literal { // Converts this literal to the given shape. Returns an error is the // conversion is not possible. + // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. StatusOr<std::unique_ptr<Literal>> ConvertToShape( - const Shape& dest_shape) const; + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; // Creates a scalar literal value zero of the given primitive type. 
static Literal Zero(PrimitiveType primitive_type); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index d4d67872cf..da16976d06 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -623,6 +623,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 971c2935c8..f9fabd8a35 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -302,7 +302,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot strength reduction on platforms where it causes a slowdown. bool enable_dot_strength_reduction_; - // Disable convolution simplication on platforms where it causes a slowdown. + // Disable convolution simplification on platforms where it causes a slowdown. bool enable_conv_simplification_; }; @@ -1121,10 +1121,10 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction, Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { auto operand = broadcast->mutable_operand(0); + auto dims = broadcast->dimensions(); // A degenerate broadcast of a reshape that does not change the number of // elements can be replaced by a reshape. 
- if (std::is_sorted(broadcast->dimensions().begin(), - broadcast->dimensions().end()) && + if (std::is_sorted(dims.begin(), dims.end()) && ShapeUtil::ElementsIn(broadcast->shape()) == ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> reshape(X) where " @@ -1142,8 +1142,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { VLOG(10) << "transform broadcast(X) -> transpose(X) where " "n(broadcast(X)) == n(X)"; return ReplaceWithNewInstruction( - broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand, - broadcast->dimensions())); + broadcast, + HloInstruction::CreateTranspose(broadcast->shape(), operand, dims)); } // A broadcast of a reshape which merely inserts 1-sized dimensions can @@ -1157,7 +1157,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { if (merely_inserts_or_deletes_1_sized_dimensions && deleted_indices.empty()) { std::reverse(inserted_indices.begin(), inserted_indices.end()); - auto dims = broadcast->dimensions(); for (auto inserted_index : inserted_indices) { dims.erase(dims.begin() + inserted_index); } @@ -1201,6 +1200,19 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return user->ReplaceAllUsesWith(new_broadcast); } } + return Status::OK(); + } + + // Merge two consecutive broadcasts into a single one. 
+ if (operand->opcode() == HloOpcode::kBroadcast) { + std::vector<int64> new_dimensions; + for (auto dim : operand->dimensions()) { + new_dimensions.push_back(dims[dim]); + } + return ReplaceWithNewInstruction( + broadcast, + HloInstruction::CreateBroadcast( + broadcast->shape(), operand->mutable_operand(0), new_dimensions)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index f0590943be..c48196e861 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -57,10 +57,10 @@ class AlgebraicSimplifier : public HloPassInterface { bool is_layout_sensitive_; ValidBitcastCallback valid_bitcast_callback_; - // Enable dot simplication on platforms where it is profitable. + // Enable dot simplification on platforms where it is profitable. bool enable_dot_strength_reduction_; - // Enable convolution simplication on platforms where it is profitable. + // Enable convolution simplification on platforms where it is profitable. bool enable_conv_simplification_; }; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 451294ef5d..3b80a827bf 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -35,6 +35,8 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/str_util.h" +using ::testing::ElementsAre; + namespace xla { namespace { @@ -2462,6 +2464,55 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { op::DynamicSlice(op::Parameter(), op::Parameter())); } +// Test that two consecutive broadcasts can be merged to one. 
+TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* input_array = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1<float>({3, 4}))); + HloInstruction* inner_bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, input_array, {1})); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root->dimensions(), ElementsAre(2)); +} + +// Test that two consecutive broadcasts can be merged to one. 
+TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3}); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + // The initial dimensions go to places 0 and 2 in the 3-dim array, + // and to places 1 and 3 in the 4-dim array, + HloInstruction* inner_bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r3f32, param0, {0, 2})); + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); + EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector<int64> input_spatials; std::vector<int64> symmetric_pad_spatials; diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 7195c31d9c..c26d2feef5 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -606,8 +606,10 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( continue; } if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { - TF_ASSIGN_OR_RETURN(auto converted_literal, - hlo->literal().ConvertToShape(hlo->shape())); + TF_ASSIGN_OR_RETURN( + auto converted_literal, + hlo->literal().ConvertToShape(hlo->shape(), + /*round_f32_to_bf16=*/true)); auto new_constant = 
computation->AddInstruction( HloInstruction::CreateConstant(std::move(converted_literal))); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 6664496ab6..c83da9eddc 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -100,7 +100,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr<HloModuleConfig> module_config, CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, *user_computation)); + &execution_options, user_computation)); TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module, computation_tracker_.BuildHloModule( diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 33e19efc72..b4b53ae2ed 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -127,7 +127,7 @@ class Compiler { // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are // applied to module. Generally a module should be passed through RunHloPasses - // prior to calling this method because the some HLO passes are required for + // prior to calling this method because some HLO passes are required for // correctness. Takes ownership of the HLO module and is free to transform it. 
// // The compiler may optionally specialize to the individual device diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 093db020c0..0faa9e9c41 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -670,6 +670,22 @@ cc_library( ], ) +tf_cc_test( + name = "ir_emission_utils_test", + srcs = ["ir_emission_utils_test.cc"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + cc_library( name = "cpu_layout_assignment", srcs = ["cpu_layout_assignment.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 788217aab6..f209a69e3c 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -34,14 +34,16 @@ bool PotentiallyImplementedAsEigenConvolution( // // To be sufficient, certain layout constraints need to be satisfied as well. const Shape& input_shape = convolution.operand(0)->shape(); - const Shape& kernel_shape = convolution.operand(0)->shape(); + const Shape& kernel_shape = convolution.operand(1)->shape(); if (ShapeUtil::HasZeroElements(input_shape) || ShapeUtil::HasZeroElements(kernel_shape)) { return false; } + // Make sure input and kernel has the same data type. + CHECK( + ShapeUtil::SameElementTypeIgnoringFpPrecision(input_shape, kernel_shape)); // TODO(b/65408531): Explore using Eigen dot for complex64 type. 
- if (ShapeUtil::ElementIsComplex(input_shape) || - ShapeUtil::ElementIsComplex(kernel_shape)) { + if (ShapeUtil::ElementIsComplex(input_shape)) { return false; } if (window_util::HasWindowReversal(convolution.window())) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc new file mode 100644 index 0000000000..215f48c4cc --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" + +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +TEST(IrEmitterTest, ConvWithZeroSizedKernelNotImplementedAsEigen) { + const char* const hlo_string = R"( +HloModule ModuleWithConv + +ENTRY Conv { + input = f32[32,50,28,28]{3,2,1,0} parameter(0) + kernel = f32[0,32,5,5]{3,2,1,0} parameter(1) + ROOT convolution = f32[64,50,24,24]{3,2,1,0} convolution(input, kernel), + window={size=5x5}, + dim_labels=b01f_01io->b01f +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + tools::Parse(hlo_string)); + + HloComputation* entry_computation = module->entry_computation(); + + HloInstruction* conv_instr = entry_computation->root_instruction(); + EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution(*conv_instr)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 3b8056d505..3405277d44 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -438,12 +438,14 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. - ir_builder_.CreateMemCpy(program_buffer_address, acquired_pointer, - length_32, 1); + ir_builder_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, + acquired_pointer, + /*SrcAlign=*/1, length_32); } else { // Outfeed -- copy from the in-program address to the acquired buffer. 
- ir_builder_.CreateMemCpy(acquired_pointer, program_buffer_address, - length_32, 1); + ir_builder_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, + program_buffer_address, + /*SrcAlign=*/1, length_32); } ir_builder_.CreateCall(release_func, @@ -2441,7 +2443,8 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { auto* memcpy_instruction = ir_builder_.CreateMemCpy( - target, source, element_count * primitive_type_size, element_alignment); + target, /*DstAlign=*/element_alignment, source, + /*SrcAlign=*/element_alignment, element_count * primitive_type_size); // The memcpy does the load and the store internally. The aliasing related // metadata has to reflect that. @@ -2905,7 +2908,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* destination_value = GetEmittedValueFor(&destination); int64 source_size = ByteSizeOf(source.shape()); // TODO(b/63762267): Be more aggressive about specifying alignment. - ir_builder_.CreateMemCpy(destination_value, source_value, source_size, 1); + ir_builder_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, + /*SrcAlign=*/1, source_size); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 86e8be8461..fb28280fad 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -128,6 +128,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // one of the following properties: // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Operations that are not thread safe (like infeed and rng). // *) Tuple-shaped. // TODO(b/27458679) Parallelize instructions which are skipped here. 
auto opcode = instruction->opcode(); @@ -135,7 +136,8 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall || opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter || opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast || - opcode == HloOpcode::kFft || + opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || + opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction)) || PotentiallyImplementedAsEigenDot(*instruction) || diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index 90191221eb..13eb75a572 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -80,5 +80,39 @@ TEST_F(ParallelTaskAssignmentTest, EXPECT_FALSE(changed); } +TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_rng + ENTRY Rng { + src0 = f32[] parameter(0) + src1 = f32[] parameter(1) + ROOT rng0 = f32[1234567,2]{1,0} rng(f32[] src0, f32[] src1), + distribution=rng_uniform + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_infeed_outfeed + ENTRY InfeedOutfeed { + infeed0 = u32[12345678,2]{1,0} infeed() + ROOT outfeed0 = u32[12345678,2]{1,0} outfeed(infeed0) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + 
EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 279edd4ba8..cd7cbbdd71 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -109,6 +109,11 @@ StatusOr<bool> HloCSE::Run(HloModule* module) { continue; } + // Skip instructions which have side effects. + if (instruction->HasSideEffect()) { + continue; + } + // An instruction is considered to be equivalent to another only if they // share the exact same set of operands. So to find equivalent // instructions, we just search among instructions which share operand(0) @@ -118,7 +123,7 @@ StatusOr<bool> HloCSE::Run(HloModule* module) { tensorflow::gtl::InlinedVector<HloInstruction*, 8> equivalent_instructions; for (HloInstruction* user : operand->users()) { - if (user != instruction && + if (user != instruction && !user->HasSideEffect() && user->Identical(*instruction, eq_instructions, eq_computations, is_layout_sensitive_)) { equivalent_instructions.push_back(user); diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 3601a790c4..df8853f34f 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -414,8 +414,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { EXPECT_THAT(root, op::Add(rng1, rng2)); } -// TODO(b/28245743): Handle impure functions correctly in CSE. -TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { +TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { // Test that two calls to an impure function are not commoned. RNG // is the source of the impurity. 
@@ -458,14 +457,16 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Add(op::Map(), op::Map())); + VLOG(3) << "before: " << module->ToString(); + HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + VLOG(3) << "after: " << module->ToString(); EXPECT_EQ(4, computation->instruction_count()); root = computation->root_instruction(); - auto operand = root->operand(0)->operand(0); - EXPECT_THAT(operand, op::Map()); - EXPECT_THAT(root, op::Add(operand, operand)); + EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant()))); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 2037764dae..595c531ccf 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -237,8 +237,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { const Shape& parameter_shape = module_config.entry_computation_layout().parameter_layout(i).shape(); - TF_RET_CHECK( - ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape)) + TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i), + parameter_shape)) << "HloModuleConfig has different shape for parameter " << i << " than the HLO module. Expected: " << ShapeUtil::HumanStringWithLayout( @@ -247,7 +247,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( } const Shape& result_shape = module_config.entry_computation_layout().result_layout().shape(); - TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape)) + TF_RET_CHECK( + ShapeUtil::Compatible(expected_program_shape.result(), result_shape)) << "HloModuleConfig has different result shape than the HLO module. 
" "Expected: " << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 5690a89909..499f280211 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -69,6 +69,68 @@ LocalService::LocalService(const ServiceOptions& options, std::unique_ptr<Backend> execute_backend) : Service(options, std::move(execute_backend)) {} +namespace { + +// Retrieves the parameter metadata for the given computation and parameter +// number. +// +// If the parameter number is invalid for this computation, nullopt is +// returned. When the return value has_value(), nullptr will never be +// the held value. +tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata( + const XlaComputation& computation, int parameter_number) { + for (const HloComputationProto& comp : computation.proto().computations()) { + if (comp.id() == computation.proto().entry_computation_id()) { + for (const HloInstructionProto& instr : comp.instructions()) { + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) && + instr.parameter_number() == parameter_number) { + if (!instr.has_metadata()) { + return tensorflow::gtl::nullopt; + } + return &instr.metadata(); + } + } + } + } + return tensorflow::gtl::nullopt; +} + +ExecutionOptions CreateExecutionOptions( + const ExecutableBuildOptions& build_options, + const ProgramShape* program_shape) { + ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + if (build_options.hlo_profile().has_value()) { + execution_options.mutable_debug_options()->set_xla_hlo_profile( + *build_options.hlo_profile()); + } + if (build_options.generate_hlo_graph().has_value()) { + execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( + build_options.generate_hlo_graph().value()); + } + if (build_options.dump_optimized_hlo_proto_to().has_value()) { + 
execution_options.mutable_debug_options() + ->set_xla_dump_optimized_hlo_proto_to( + build_options.dump_optimized_hlo_proto_to().value()); + } + if (build_options.dump_per_pass_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_per_pass_hlo_proto_to( + build_options.dump_per_pass_hlo_proto_to().value()); + } + if (build_options.result_layout() != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *build_options.result_layout(); + } else { + *execution_options.mutable_shape_with_output_layout() = + program_shape->result(); + LayoutUtil::SetToDefaultLayout( + execution_options.mutable_shape_with_output_layout()); + } + return execution_options; +} + +} // namespace + StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts, @@ -118,44 +180,78 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable( *build_options.result_layout(), program_shape->result())); } - ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (build_options.hlo_profile().has_value()) { - execution_options.mutable_debug_options()->set_xla_hlo_profile( - *build_options.hlo_profile()); - } - if (build_options.generate_hlo_graph().has_value()) { - execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( - build_options.generate_hlo_graph().value()); - } - if (build_options.dump_optimized_hlo_proto_to().has_value()) { - execution_options.mutable_debug_options() - ->set_xla_dump_optimized_hlo_proto_to( - build_options.dump_optimized_hlo_proto_to().value()); + ExecutionOptions execution_options = + CreateExecutionOptions(build_options, program_shape.get()); + TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config, + CreateModuleConfig(*program_shape, argument_layouts, + &execution_options, user_computation)); + + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, 
+ execute_backend_->stream_executor(build_options.device_ordinal())); + + return BuildExecutable(versioned_handle, std::move(module_config), + execute_backend_.get(), executor, + build_options.device_allocator()); +} + +StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts, + const ExecutableBuildOptions& build_options) { + const HloModuleProto& proto = computation.proto(); + TF_RET_CHECK(proto.has_program_shape()); + const ProgramShape& program_shape = proto.program_shape(); + + // Validate incoming layouts. + if (argument_layouts.size() != program_shape.parameters_size()) { + return InvalidArgument( + "Invalid number of arguments for computation: expected %d, got %zu.", + program_shape.parameters_size(), argument_layouts.size()); } - if (build_options.dump_per_pass_hlo_proto_to().has_value()) { - execution_options.mutable_debug_options() - ->set_xla_dump_per_pass_hlo_proto_to( - build_options.dump_per_pass_hlo_proto_to().value()); + + for (int i = 0; i < argument_layouts.size(); ++i) { + const Shape& argument_shape = *argument_layouts[i]; + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); + if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { + tensorflow::gtl::optional<const OpMetadata*> metadata = + ParameterMetadata(computation, /*parameter_number=*/i); + auto metadata_string = [&metadata]() -> string { + if (!metadata.has_value()) { + return ""; + } + CHECK(metadata.value() != nullptr); + const OpMetadata& m = *metadata.value(); + if (!m.source_file().empty()) { + return tensorflow::strings::Printf( + " (%s:%d)", m.source_file().c_str(), m.source_line()); + } + return ""; + }; + return InvalidArgument( + "Invalid argument shape for argument %d%s, expected %s, got %s.", i, + metadata_string().c_str(), + ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), + 
ShapeUtil::HumanString(argument_shape).c_str()); + } } if (build_options.result_layout() != nullptr) { - *execution_options.mutable_shape_with_output_layout() = - *build_options.result_layout(); - } else { - *execution_options.mutable_shape_with_output_layout() = - program_shape->result(); - LayoutUtil::SetToDefaultLayout( - execution_options.mutable_shape_with_output_layout()); + TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( + *build_options.result_layout(), program_shape.result())); } + + ExecutionOptions execution_options = + CreateExecutionOptions(build_options, &program_shape); + TF_ASSIGN_OR_RETURN( std::unique_ptr<HloModuleConfig> module_config, - CreateModuleConfig(*program_shape, argument_layouts, &execution_options, - *user_computation)); + CreateModuleConfig(program_shape, argument_layouts, &execution_options)); TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); - return BuildExecutable(versioned_handle, std::move(module_config), + return BuildExecutable(proto, std::move(module_config), execute_backend_.get(), executor, build_options.device_allocator()); } diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 15e120685e..06567cabd6 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -19,6 +19,7 @@ limitations under the License. 
#include <memory> #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -50,6 +51,18 @@ class LocalService : public Service { const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts, const ExecutableBuildOptions& options); + // Builds an Executable with the given XlaComputation, argument layouts and + // options. If result_layout is non-null, then the executable is compiled to + // produce a result of the given layout. If device_allocator is non-null, + // then the compiler may use it to allocate temp space on the device. The + // compiler is responsible for freeing any memory it allocates this way. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr<std::unique_ptr<Executable>> CompileExecutable( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts, + const ExecutableBuildOptions& build_options); + // Returns the device ordinal that corresponds to the given replica number. // // This returns an error if there is not a one-to-one correspondence of diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index e62bafc50b..f15117f45c 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -53,6 +53,14 @@ bool IsReshapeOrTranspose(const HloInstruction* instruction) { instruction->opcode() == HloOpcode::kTranspose; } +// Returns true if `a` is a broadcast instruction to target shape `shape` and +// its operand is a scalar. 
+bool IsBroadcastScalarToShape(const HloInstruction* a, const Shape& shape) { + return a->opcode() == HloOpcode::kBroadcast && + ShapeUtil::SameDimensions(a->shape(), shape) && + ShapeUtil::IsScalar(a->operand(0)->shape()); +} + // Returns true iff `instruction` can change its shape simply by adjusting // metadata. bool CanTriviallyChangeShape(const HloInstruction* instruction) { @@ -88,6 +96,7 @@ bool CanTriviallyChangeShape(const HloInstruction* instruction) { instruction->user_count() == 1) { return true; } + return false; } @@ -148,6 +157,8 @@ bool AllOperandsHaveEasyShapeChanges( // or // 2. Are one of kConstant, kRng, and scalars that can change shape // trivially, + // or + // 3. Are broadcast with a scalar operand. for (const HloInstruction* operand : instruction->operands()) { if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { VLOG(5) << "Operand shape differs from output shape; may be " @@ -158,6 +169,12 @@ bool AllOperandsHaveEasyShapeChanges( return false; } + // Skip the rest checks if the current operand is first_reshape_operand + // itself. + if (first_reshape_operand == operand) { + continue; + } + if (AreEquivalentReshapes(first_reshape_operand, operand)) { VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " << first_reshape_operand->ToString(print_no_metadata) @@ -171,6 +188,12 @@ bool AllOperandsHaveEasyShapeChanges( continue; } + if (IsBroadcastScalarToShape(operand, first_reshape_operand->shape())) { + VLOG(5) << "Broadcast scalar to shape: " + << operand->ToString(print_no_metadata); + continue; + } + // TODO(someone): Look into supporting general ops for the operands as // well. 
VLOG(5) << "Operand is neither equalivant to the first Reshape operand" @@ -222,6 +245,12 @@ HloInstruction* UpdateOperand(HloComputation* computation, VLOG(5) << "Using existing operand of kReshape or kTranspose"; return operand->mutable_operand(0); } + case HloOpcode::kBroadcast: + CHECK(IsBroadcastScalarToShape(operand, first_reshape_operand->shape())); + VLOG(5) << "Changing broadcast"; + return computation->AddInstruction( + operand->CloneWithNewOperands(new_shape, operand->operands())); + default: LOG(FATAL) << "Unexpected operand opcode during update: " << operand; } diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index aac8638a54..4e0a0a8832 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -560,5 +560,25 @@ TEST_F(ReshapeMoverTest, MultiplePasses) { op::Reshape(op::Add(param2, op::Reshape(op::Add(param0, param1))))); } +TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) { + const string hlo_string = R"( + HloModule TransposeMulInversedTransposeModule + ENTRY TransposeMulInversedTranspose { + src0 = f32[1,20,8,32]{3,2,1,0} parameter(0) + transpose0 = f32[1,8,20,32]{3,2,1,0} transpose(src0), dimensions={0,2,1,3} + src1 = f32[] parameter(1) + broadcast0 = f32[1,8,20,32]{3,2,1,0} broadcast(src1), dimensions={} + ROOT multiply0 = f32[1,8,20,32]{3,2,1,0} multiply(transpose0, broadcast0) + } + )"; + + ParseAndVerifyModule(hlo_string.c_str()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Transpose(op::Multiply())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 0becc9d8f8..1d379f0d03 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ 
b/tensorflow/compiler/xla/service/service.cc @@ -272,7 +272,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice<const Shape*> argument_shapes, const ExecutionOptions* execution_options, - const UserComputation& user_computation) { + const UserComputation* user_computation) { auto config = MakeUnique<HloModuleConfig>(program_shape); auto* computation_layout = config->mutable_entry_computation_layout(); @@ -286,8 +286,15 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig( // ProgramShape. if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { + if (user_computation == nullptr) { + return InvalidArgument( + "Argument does not match shape of computation parameter %d: want " + "%s, got %s", + i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(*argument_shapes[i]).c_str()); + } return InvalidParameterArgument( - *user_computation.ParameterMetadata(i).value(), + *user_computation->ParameterMetadata(i).value(), "Argument does not match shape of computation parameter %d: want %s, " "got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), @@ -330,7 +337,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, const ExecutionOptions& execution_options, - const UserComputation& user_computation) { + const UserComputation* user_computation) { std::vector<const Shape*> argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); @@ -778,7 +785,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, TF_ASSIGN_OR_RETURN( std::unique_ptr<HloModuleConfig> module_config, CreateModuleConfig(*program_shape, replicated_arguments.front(), - request.execution_options(), *user_computation)); + 
request.execution_options(), user_computation)); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -854,6 +861,33 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return tensorflow::Status::OK(); } +tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, + ExecuteResponse* result) { + ExecuteParallelRequest parallel_arg; + *parallel_arg.add_requests() = *arg; + ExecuteParallelResponse parallel_result; + TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); + // The "result device" selection is a bit hacky, but better than assuming it + // is device 0. We have b/76035356 for restructuring the client API to clean + // up the current asymmetries and support more functionalities. + for (int64 i = 0; i < parallel_result.responses_size(); ++i) { + TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, + allocation_tracker_.ResolveForReplica( + parallel_result.responses(i).output(), 0)); + const Shape& shape = buffer->on_host_shape(); + if (!ShapeUtil::IsEmptyTuple(shape)) { + *result = parallel_result.responses(i); + VLOG(3) << "Fetching result from device " << i << ": " + << ShapeUtil::HumanString(shape); + return Status::OK(); + } + } + TF_RET_CHECK(parallel_result.responses_size() > 0); + *result = parallel_result.responses(0); + VLOG(1) << "Defaulting to device 0 result"; + return Status::OK(); +} + tensorflow::Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { VLOG(1) << "running execute request: " << arg->ShortDebugString(); @@ -870,13 +904,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, // If we received multiple device handles, we must partition the module. 
if (arg->execution_options().device_handles_size() > 1) { - ExecuteParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); - TF_RET_CHECK(parallel_result.responses_size() > 0); - *result = parallel_result.responses(0); - return Status::OK(); + return ExecuteOneToN(arg, result); } TF_ASSIGN_OR_RETURN( @@ -894,7 +922,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, TF_ASSIGN_OR_RETURN( std::unique_ptr<HloModuleConfig> module_config, CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), *user_computation)); + arg->execution_options(), user_computation)); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -935,9 +963,66 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, return tensorflow::Status::OK(); } -tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* /*arg*/, - ExecuteResponse* /*result*/) { - return Unimplemented("execute-graph is not yet implemented"); +StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable( + const HloModuleProto& module_proto, + std::unique_ptr<HloModuleConfig> module_config, Backend* backend, + se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { + VLOG(1) << Printf( + "BuildExecutable on service %p with serialized module proto: %s", this, + module_proto.name().c_str()); + + TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module, + HloModule::CreateFromProto(module_proto, *module_config)); + + TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + + TF_ASSIGN_OR_RETURN( + module, backend->compiler()->RunHloPasses(std::move(module), executor, + device_allocator)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable, + backend->compiler()->RunBackend( + std::move(module), executor, device_allocator)); + + return 
std::move(executable); +} + +tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { + VLOG(1) << "running execute-graph request"; + + if (!arg->has_computation()) { + return InvalidArgument("computations may not be empty"); + } + + // TODO(b/74197823): Handle partitioning. + + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN( + std::vector<std::vector<const ShapedBuffer*>> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config, + CreateModuleConfig(arg->computation().program_shape(), + replicated_arguments.front(), + arg->execution_options())); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr<Executable> executable, + BuildExecutable(arg->computation(), std::move(module_config), + execute_backend_.get(), + execute_backend_->default_stream_executor(), + /*device_allocator=*/nullptr)); + + TF_ASSIGN_OR_RETURN( + *result->mutable_output(), + ExecuteAndRegisterResult( + executable.get(), replicated_arguments, execute_backend_.get(), + "result of " + arg->computation().name(), result->mutable_profile())); + + VLOG(1) << "successfully completed 'execute-graph' request"; + return tensorflow::Status::OK(); } tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, @@ -967,7 +1052,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, TF_ASSIGN_OR_RETURN( std::unique_ptr<HloModuleConfig> module_config, CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), *user_computation)); + arg->execution_options(), user_computation)); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -1268,7 +1353,7 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, 
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config, CreateModuleConfig(program_shape, {}, execution_options, - *user_computation)); + user_computation)); // Exclude dead parameter instructions for the purpose of computing constants. TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 96352d9096..773f0a642d 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -115,6 +115,8 @@ class Service : public ServiceInterface { // Executes a computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, ExecuteResponse* result) override; @@ -258,7 +260,7 @@ class Service : public ServiceInterface { const ProgramShape& program_shape, tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, const ExecutionOptions& execution_options, - const UserComputation& user_computation); + const UserComputation* user_computation = nullptr); protected: friend class LocalExecutable; @@ -286,7 +288,7 @@ class Service : public ServiceInterface { const ProgramShape& program_shape, tensorflow::gtl::ArraySlice<const Shape*> argument_shapes, const ExecutionOptions* execution_options, - const UserComputation& user_computation); + const UserComputation* user_computation = nullptr); // Builds an Executable for the given parameters. // @@ -299,6 +301,15 @@ class Service : public ServiceInterface { perftools::gputools::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator = nullptr); + // Builds an Executable for the given HLO module proto. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. 
+ StatusOr<std::unique_ptr<Executable>> BuildExecutable( + const HloModuleProto& module_proto, + std::unique_ptr<HloModuleConfig> module_config, Backend* backend, + perftools::gputools::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator = nullptr); + // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables( @@ -346,6 +357,12 @@ class Service : public ServiceInterface { const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>& adder); + // Executes a single computation which has more than one target device. + // The N devices are expected to all return an empty tuple, but one, which + // will be the result of this computation. + tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg, + ExecuteResponse* result); + // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index d3d55634c9..3d3e1d60f2 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that makes the following transformations on while loops: // // - A while loop with static trip count of 0 is deleted. -// - A while loops with static trip count of 1 is replaced by its body (sans +// - A while loop with static trip count of 1 is replaced by its body (sans // loop). // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. 
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index 063e312df6..8763e588c4 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -// HLO pass that replaces zero sized Hlos with an zero sized constant literal. +// HLO pass that replaces zero sized Hlos with a zero sized constant literal. namespace xla { class ZeroSizedHloElimination : public HloPassInterface { public: diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 7fb7919674..26022278e5 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -190,6 +190,7 @@ cc_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -386,6 +387,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1372,6 +1374,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", 
"//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index 3f6fd7c65d..ec3b46acfe 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -28,11 +29,11 @@ namespace { class AxpySimpleTest : public ClientLibraryTestBase {}; TEST_F(AxpySimpleTest, AxTenValues) { - ComputationBuilder builder(client_, "ax_10"); + XlaBuilder builder("ax_10"); auto alpha = builder.ConstantR0<float>(3.1415926535); auto x = builder.ConstantR1<float>( {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Mul(alpha, x); + builder.Mul(alpha, x); std::vector<float> expected = { -3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796, @@ -46,7 +47,7 @@ XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { auto x = builder.ConstantR1<float>({}); auto y = builder.ConstantR1<float>({}); auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + builder.Add(ax, y); std::vector<float> expected = {}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); @@ -60,7 +61,7 @@ TEST_F(AxpySimpleTest, AxpyTenValues) { auto y = builder.ConstantR1<float>( {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + builder.Add(ax, y); TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape()); diff --git 
a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index a677986cd9..d9bd1ce6eb 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -96,6 +96,20 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer( } StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<GlobalData*> arguments, + const Shape* shape_with_output_layout) { + ExecutionOptions execution_options = execution_options_; + if (shape_with_output_layout != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *shape_with_output_layout; + } + return client_->ExecuteAndTransfer(computation, arguments, + &execution_options); +} + +template <> +StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer( ComputationBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments, const Shape* shape_with_output_layout) { @@ -104,6 +118,15 @@ StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer( return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); } +template <> +StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer( + XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments, + const Shape* shape_with_output_layout) { + // Build the computation, as a convenience. 
+ TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); +} + std::unique_ptr<GlobalData> ClientLibraryTestBase::ExecuteOrDie( ComputationBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) { @@ -142,16 +165,18 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } +template <typename BuilderT> void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, shape_with_layout)); } +template <typename BuilderT> void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, @@ -249,8 +274,28 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return choose(0); } +tensorflow::Status +ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, + tensorflow::gtl::ArraySlice<GlobalData*> /*arguments*/, + const std::function<void(const Literal& actual, + const string& error_message)>& /*verify_output*/) { + return Unimplemented("not yet implemented for XlaComputation"); +} + +tensorflow::Status +ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( + const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, + tensorflow::gtl::ArraySlice<GlobalData*> /*arguments*/, + const std::function<void(const Literal& actual, + const string& error_message)>& /*verify_output*/, + const Shape* /*output_with_layout*/) { + return 
Unimplemented("not yet implemented for XlaComputation"); +} + +template <typename BuilderT> tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in, const Shape* shape_with_layout) { std::vector<GlobalData*> arguments(arguments_passed_in.begin(), @@ -307,8 +352,9 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( return tensorflow::Status::OK(); } +template <typename BuilderT> tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in, ErrorSpec error, const Shape* shape_with_layout) { std::vector<GlobalData*> arguments(arguments_passed_in.begin(), @@ -522,33 +568,6 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, return array; } -std::unique_ptr<GlobalData> -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle) { - return CreateParameterAndTransferLiteral(parameter_number, literal, name, - nullptr, builder, data_handle); -} - -std::unique_ptr<GlobalData> -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr<Literal> converted_literal; - if (use_bfloat16_) { - converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } - std::unique_ptr<GlobalData> data = - client_->TransferToServer(*param_literal, device_handle) - 
.ConsumeValueOrDie(); - *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); - return data; -} - ComputationDataHandle ClientLibraryTestBase::AddParam( const Literal& argument, ComputationBuilder* builder) { ComputationDataHandle data_handle; @@ -563,4 +582,24 @@ ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); } +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice<GlobalData*> arguments, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + XlaBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice<GlobalData*> arguments, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + XlaBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error, + const Shape* shape_with_layout); + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index ba0319990b..01aa6c756f 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -94,15 +95,22 @@ class ClientLibraryTestBase : public ::testing::Test { StatusOr<std::unique_ptr<GlobalData>> Execute( ComputationBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments); + + template <typename BuilderT> StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice<GlobalData*> arguments, + BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments, const Shape* shape_with_output_layout = nullptr); + StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer( const Computation& computation, tensorflow::gtl::ArraySlice<GlobalData*> arguments, const Shape* shape_with_output_layout = nullptr); + StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice<GlobalData*> arguments, + const Shape* shape_with_output_layout = nullptr); + // Convenience OrDie variants of above methods. 
std::unique_ptr<GlobalData> ExecuteOrDie( ComputationBuilder* builder, @@ -130,12 +138,12 @@ class ClientLibraryTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error); - template <typename NativeT> - void ComputeAndCompareR1(ComputationBuilder* builder, + template <typename NativeT, typename BuilderT> + void ComputeAndCompareR1(BuilderT* builder, tensorflow::gtl::ArraySlice<NativeT> expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments); - template <typename NativeT> - void ComputeAndCompareR1(ComputationBuilder* builder, + template <typename NativeT, typename BuilderT> + void ComputeAndCompareR1(BuilderT* builder, tensorflow::gtl::ArraySlice<NativeT> expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error); @@ -179,22 +187,26 @@ class ClientLibraryTestBase : public ::testing::Test { // Build and run the computation and compare the result with the given // literal. shape_with_layout indicates the result layout to request when // calling Execute. + template <typename BuilderT> void ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, const Shape* shape_with_layout = nullptr); + template <typename BuilderT> void ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. 
+ template <typename BuilderT> tensorflow::Status ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, const Shape* shape_with_layout = nullptr); + template <typename BuilderT> tensorflow::Status ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); @@ -266,17 +278,19 @@ class ClientLibraryTestBase : public ::testing::Test { // server, then stores into "data_handle" the global handle for that // parameter. When the use_bfloat16 flag is set but the literal has F32 // elements, the literal will be converted to BF16 before being transferred. + template <typename BuilderT, typename HandleT> std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle); + BuilderT* builder, HandleT* data_handle); // As above, but the caller can specify the device that the literal is // transferred to. If device_handle is nullptr, the literal will be // transferred to the default device. + template <typename BuilderT, typename HandleT> std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const DeviceHandle* device_handle, BuilderT* builder, + HandleT* data_handle); // Creates a parameter instruction and sets the value that will be passed to // the computation as specified. 
This function must be used for all parameters @@ -399,6 +413,18 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); + tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::XlaComputation& computation, const Literal& expected, + tensorflow::gtl::ArraySlice<GlobalData*> arguments, + const std::function<void(const Literal& actual, + const string& error_message)>& verify_output); + tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( + const xla::XlaComputation& computation, const Literal& expected, + tensorflow::gtl::ArraySlice<GlobalData*> arguments, + const std::function<void(const Literal& actual, + const string& error_message)>& verify_output, + const Shape* output_with_layout = nullptr); + // Executes the computation and calculates the expected reference value using // the HloEvaluator. Returns two literal in the order of (expected, actual). StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>> @@ -440,9 +466,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments, error); } -template <typename NativeT> +template <typename NativeT, typename BuilderT> void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected, + BuilderT* builder, tensorflow::gtl::ArraySlice<NativeT> expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments) { std::unique_ptr<Literal> expected_literal = Literal::CreateR1<NativeT>(expected); @@ -450,9 +476,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } -template <typename NativeT> +template <typename NativeT, typename BuilderT> void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected, + BuilderT* builder, tensorflow::gtl::ArraySlice<NativeT> expected, tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { 
static_assert(std::is_same<NativeT, float>::value || std::is_same<NativeT, double>::value || @@ -628,6 +654,37 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2( return result; } +template <typename BuilderT, typename HandleT> +std::unique_ptr<GlobalData> +ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, + const Literal& literal, + const string& name, + BuilderT* builder, + HandleT* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +template <typename BuilderT, typename HandleT> +std::unique_ptr<GlobalData> +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, BuilderT* builder, + HandleT* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr<Literal> converted_literal; + if (use_bfloat16_) { + converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr<GlobalData> data = + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index f7b04debd4..02272d6017 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -207,9 +208,9 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { // // Splits an empty vector into an empty matrix. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1<float>({}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, @@ -221,10 +222,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { // Splits a vector into a matrix. XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, @@ -241,9 +242,9 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { // // Transposes a 2x0 array to a 0x2 array. 
XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 2)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -255,10 +256,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { // Transposes a 2-dimensional row vector to a column vector. XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = Literal::CreateFromArray(*simple); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -272,10 +273,10 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { // Transposes a 2-dimensional array. XLA_TEST_P(ReshapeTest, TransposeAsReshape) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, @@ -291,11 +292,11 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { // does not handle zero-sized shapes correctly. Failed last on 2017-11-30 // with an incorrect result rank. // -// Transposes a 0x4 array with ComputationBuilder::Trans. +// Transposes a 0x4 array with XlaBuilder::Transpose. 
XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 4)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Transpose(parameter, {1, 0}); @@ -306,10 +307,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { // Transposes a 2-dimensional array with ComputationBuilder::Trans. XLA_TEST_P(ReshapeTest, Transpose4x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Transpose(parameter, {1, 0}); @@ -327,9 +328,9 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). 
XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D<float>(6, 0)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index dae402204f..dcd235f876 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -13,20 +13,23 @@ load("//tensorflow:tensorflow.bzl", "py_test") filegroup( name = "all_files", srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], + include = ["**/*"], + exclude = ["**/OWNERS"], ), visibility = ["//tensorflow:__subpackages__"], ) py_library( name = "init_py", - srcs = [ - "__init__.py", - ], + srcs = ["__init__.py"], srcs_version = "PY2AND3", + deps = [ + "custom_export_strategy", + ":custom_loss_head", + ":estimator", + ":model", + ":trainer_hooks", + ], ) py_library( diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index f70b29c43b..8cfe4a727a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -479,6 +479,10 @@ py_test( size = "small", srcs = ["prefetching_ops_test.py"], srcs_version = "PY2AND3", + tags = [ + "manual", + "no_oss", + ], deps = [ "//tensorflow/contrib/data/python/ops:prefetching_ops", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 94f800e8a5..d0131896a1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ 
b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -468,6 +468,31 @@ class BucketBySequenceLength(test.TestCase): self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) self.assertEqual(sorted(boundaries), sorted(lengths_val)) + def testTupleElements(self): + + def elements_gen(): + text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] + label = [1, 2, 1, 2] + for x, y in zip(text, label): + yield (x, y) + + def element_length_fn(x, y): + del y + return array_ops.shape(x)[0] + + dataset = dataset_ops.Dataset.from_generator( + generator=elements_gen, + output_shapes=(tensor_shape.TensorShape([None]), + tensor_shape.TensorShape([])), + output_types=(dtypes.int32, dtypes.int32)) + dataset = dataset.apply(grouping.bucket_by_sequence_length( + element_length_func=element_length_fn, + bucket_batch_sizes=[2, 2, 2], + bucket_boundaries=[0, 8])) + shapes = dataset.output_shapes + self.assertEqual([None, None], shapes[0].as_list()) + self.assertEqual([None], shapes[1].as_list()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index ae10d2eb22..36591c055a 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -140,9 +140,9 @@ def bucket_by_sequence_length(element_length_func, batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64) - def element_to_bucket_id(element): + def element_to_bucket_id(*args): """Return int64 id of the length bucket for this element.""" - seq_length = element_length_func(element) + seq_length = element_length_func(*args) boundaries = list(bucket_boundaries) buckets_min = [np.iinfo(np.int32).min] + boundaries diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py index 507ceb3585..68e0d9cb82 100644 --- 
a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py @@ -16,6 +16,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib import distributions from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -25,23 +27,23 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -ds = distributions +tfd = distributions class DistributionTest(test.TestCase): def testParamShapesAndFromParams(self): classes = [ - ds.Normal, - ds.Bernoulli, - ds.Beta, - ds.Chi2, - ds.Exponential, - ds.Gamma, - ds.InverseGamma, - ds.Laplace, - ds.StudentT, - ds.Uniform, + tfd.Normal, + tfd.Bernoulli, + tfd.Beta, + tfd.Chi2, + tfd.Exponential, + tfd.Gamma, + tfd.InverseGamma, + tfd.Laplace, + tfd.StudentT, + tfd.Uniform, ] sample_shapes = [(), (10,), (10, 20, 30)] @@ -63,15 +65,15 @@ class DistributionTest(test.TestCase): with self.test_session(): # Note: we cannot easily test all distributions since each requires # different initialization arguments. We therefore spot test a few. 
- normal = ds.Normal(loc=1., scale=2., validate_args=True) + normal = tfd.Normal(loc=1., scale=2., validate_args=True) self.assertEqual(normal.parameters, normal.copy().parameters) - wishart = ds.WishartFull(df=2, scale=[[1., 2], [2, 5]], - validate_args=True) + wishart = tfd.WishartFull(df=2, scale=[[1., 2], [2, 5]], + validate_args=True) self.assertEqual(wishart.parameters, wishart.copy().parameters) def testCopyOverride(self): with self.test_session(): - normal = ds.Normal(loc=1., scale=2., validate_args=True) + normal = tfd.Normal(loc=1., scale=2., validate_args=True) unused_normal_copy = normal.copy(validate_args=False) base_params = normal.parameters.copy() copy_params = normal.copy(validate_args=False).parameters.copy() @@ -84,19 +86,19 @@ class DistributionTest(test.TestCase): mu = 1. sigma = 2. - normal = ds.Normal(mu, sigma, validate_args=True) + normal = tfd.Normal(mu, sigma, validate_args=True) self.assertTrue(tensor_util.constant_value(normal.is_scalar_event())) self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch())) - normal = ds.Normal([mu], [sigma], validate_args=True) + normal = tfd.Normal([mu], [sigma], validate_args=True) self.assertTrue(tensor_util.constant_value(normal.is_scalar_event())) self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch())) - mvn = ds.MultivariateNormalDiag([mu], [sigma], validate_args=True) + mvn = tfd.MultivariateNormalDiag([mu], [sigma], validate_args=True) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event())) self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch())) - mvn = ds.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True) + mvn = tfd.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event())) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch())) @@ -126,7 +128,7 @@ class DistributionTest(test.TestCase): self.assertFalse(is_scalar.eval(feed_dict={x: [1]})) def 
_GetFakeDistribution(self): - class FakeDistribution(ds.Distribution): + class FakeDistribution(tfd.Distribution): """Fake Distribution for testing _set_sample_static_shape.""" def __init__(self, batch_shape=None, event_shape=None): @@ -188,6 +190,105 @@ class DistributionTest(test.TestCase): y = dist._set_sample_static_shape(x, sample_shape) self.assertTrue(y.get_shape().ndims is None) + def testStrWorksCorrectlyScalar(self): + normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) + self.assertEqual( + ("tf.distributions.Normal(" + "\"Normal\", " + "batch_shape=(), " + "event_shape=(), " + "dtype=float16)"), # Got the dtype right. + str(normal)) + + chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") + self.assertEqual( + ("tf.distributions.Chi2(" + "\"silly\", " # What a silly name that is! + "batch_shape=(2,), " + "event_shape=(), " + "dtype=float32)"), + str(chi2)) + + exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) + self.assertEqual( + ("tf.distributions.Exponential(\"Exponential\", " + # No batch shape. + "event_shape=(), " + "dtype=float32)"), + str(exp)) + + def testStrWorksCorrectlyMultivariate(self): + mvn_static = tfd.MultivariateNormalDiag( + loc=np.zeros([2, 2]), name="MVN") + self.assertEqual( + ("tf.distributions.MultivariateNormalDiag(" + "\"MVN\", " + "batch_shape=(2,), " + "event_shape=(2,), " + "dtype=float64)"), + str(mvn_static)) + + mvn_dynamic = tfd.MultivariateNormalDiag( + loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), + name="MVN2") + self.assertEqual( + ("tf.distributions.MultivariateNormalDiag(" + "\"MVN2\", " + "batch_shape=(?,), " # Partially known. + "event_shape=(3,), " + "dtype=float32)"), + str(mvn_dynamic)) + + def testReprWorksCorrectlyScalar(self): + normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) + self.assertEqual( + ("<tf.distributions.Normal" + " 'Normal'" + " batch_shape=()" + " event_shape=()" + " dtype=float16>"), # Got the dtype right. 
+ repr(normal)) + + chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") + self.assertEqual( + ("<tf.distributions.Chi2" + " 'silly'" # What a silly name that is! + " batch_shape=(2,)" + " event_shape=()" + " dtype=float32>"), + repr(chi2)) + + exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) + self.assertEqual( + ("<tf.distributions.Exponential" + " 'Exponential'" + " batch_shape=<unknown>" + " event_shape=()" + " dtype=float32>"), + repr(exp)) + + def testReprWorksCorrectlyMultivariate(self): + mvn_static = tfd.MultivariateNormalDiag( + loc=np.zeros([2, 2]), name="MVN") + self.assertEqual( + ("<tf.distributions.MultivariateNormalDiag" + " 'MVN'" + " batch_shape=(2,)" + " event_shape=(2,)" + " dtype=float64>"), + repr(mvn_static)) + + mvn_dynamic = tfd.MultivariateNormalDiag( + loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), + name="MVN2") + self.assertEqual( + ("<tf.distributions.MultivariateNormalDiag" + " 'MVN2'" + " batch_shape=(?,)" # Partially known. 
+ " event_shape=(3,)" + " dtype=float32>"), + repr(mvn_dynamic)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index 2b7e199fad..b80c909023 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -32,6 +32,7 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe from tensorflow.examples.tutorials.mnist import input_data +layers = tf.keras.layers FLAGS = None @@ -56,15 +57,15 @@ class Discriminator(tf.keras.Model): else: assert data_format == 'channels_last' self._input_shape = [-1, 28, 28, 1] - self.conv1 = tf.layers.Conv2D( + self.conv1 = layers.Conv2D( 64, 5, padding='SAME', data_format=data_format, activation=tf.tanh) - self.pool1 = tf.layers.AveragePooling2D(2, 2, data_format=data_format) - self.conv2 = tf.layers.Conv2D( + self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format) + self.conv2 = layers.Conv2D( 128, 5, data_format=data_format, activation=tf.tanh) - self.pool2 = tf.layers.AveragePooling2D(2, 2, data_format=data_format) - self.flatten = tf.layers.Flatten() - self.fc1 = tf.layers.Dense(1024, activation=tf.tanh) - self.fc2 = tf.layers.Dense(1, activation=None) + self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format) + self.flatten = layers.Flatten() + self.fc1 = layers.Dense(1024, activation=tf.tanh) + self.fc2 = layers.Dense(1, activation=None) def call(self, inputs): """Return two logits per image estimating input authenticity. @@ -112,16 +113,16 @@ class Generator(tf.keras.Model): else: assert data_format == 'channels_last' self._pre_conv_shape = [-1, 6, 6, 128] - self.fc1 = tf.layers.Dense(6 * 6 * 128, activation=tf.tanh) + self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh) # In call(), we reshape the output of fc1 to _pre_conv_shape # Deconvolution layer. 
Resulting image shape: (batch, 14, 14, 64) - self.conv1 = tf.layers.Conv2DTranspose( + self.conv1 = layers.Conv2DTranspose( 64, 4, strides=2, activation=None, data_format=data_format) # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1) - self.conv2 = tf.layers.Conv2DTranspose( + self.conv2 = layers.Conv2DTranspose( 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format) def call(self, inputs): diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 6ab847cb78..4e1380afb2 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -32,6 +32,8 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe +layers = tf.keras.layers + class LinearModel(tf.keras.Model): """A TensorFlow linear regression model.""" @@ -39,7 +41,7 @@ class LinearModel(tf.keras.Model): def __init__(self): """Constructs a LinearModel object.""" super(LinearModel, self).__init__() - self._hidden_layer = tf.layers.Dense(1) + self._hidden_layer = layers.Dense(1) def call(self, xs): """Invoke the linear model. diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py index 6b59413141..a28bc8a43d 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py @@ -28,6 +28,8 @@ import functools import tensorflow as tf +layers = tf.keras.layers + class _IdentityBlock(tf.keras.Model): """_IdentityBlock is the block that has no conv layer at shortcut. 
@@ -49,23 +51,23 @@ class _IdentityBlock(tf.keras.Model): bn_name_base = 'bn' + str(stage) + block + '_branch' bn_axis = 1 if data_format == 'channels_first' else 3 - self.conv2a = tf.layers.Conv2D( + self.conv2a = layers.Conv2D( filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format) - self.bn2a = tf.layers.BatchNormalization( + self.bn2a = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '2a') - self.conv2b = tf.layers.Conv2D( + self.conv2b = layers.Conv2D( filters2, kernel_size, padding='same', data_format=data_format, name=conv_name_base + '2b') - self.bn2b = tf.layers.BatchNormalization( + self.bn2b = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '2b') - self.conv2c = tf.layers.Conv2D( + self.conv2c = layers.Conv2D( filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format) - self.bn2c = tf.layers.BatchNormalization( + self.bn2c = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '2c') def call(self, input_tensor, training=False): @@ -113,34 +115,34 @@ class _ConvBlock(tf.keras.Model): bn_name_base = 'bn' + str(stage) + block + '_branch' bn_axis = 1 if data_format == 'channels_first' else 3 - self.conv2a = tf.layers.Conv2D( + self.conv2a = layers.Conv2D( filters1, (1, 1), strides=strides, name=conv_name_base + '2a', data_format=data_format) - self.bn2a = tf.layers.BatchNormalization( + self.bn2a = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '2a') - self.conv2b = tf.layers.Conv2D( + self.conv2b = layers.Conv2D( filters2, kernel_size, padding='same', name=conv_name_base + '2b', data_format=data_format) - self.bn2b = tf.layers.BatchNormalization( + self.bn2b = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '2b') - self.conv2c = tf.layers.Conv2D( + self.conv2c = layers.Conv2D( filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format) - self.bn2c = tf.layers.BatchNormalization( + self.bn2c = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '2c') 
- self.conv_shortcut = tf.layers.Conv2D( + self.conv_shortcut = layers.Conv2D( filters3, (1, 1), strides=strides, name=conv_name_base + '1', data_format=data_format) - self.bn_shortcut = tf.layers.BatchNormalization( + self.bn_shortcut = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '1') def call(self, input_tensor, training=False): @@ -219,15 +221,15 @@ class ResNet50(tf.keras.Model): return _IdentityBlock( 3, filters, stage=stage, block=block, data_format=data_format) - self.conv1 = tf.layers.Conv2D( + self.conv1 = layers.Conv2D( 64, (7, 7), strides=(2, 2), data_format=data_format, padding='same', name='conv1') bn_axis = 1 if data_format == 'channels_first' else 3 - self.bn_conv1 = tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1') - self.max_pool = tf.layers.MaxPooling2D( + self.bn_conv1 = layers.BatchNormalization(axis=bn_axis, name='bn_conv1') + self.max_pool = layers.MaxPooling2D( (3, 3), strides=(2, 2), data_format=data_format) self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1)) @@ -250,11 +252,12 @@ class ResNet50(tf.keras.Model): self.l5b = id_block([512, 512, 2048], stage=5, block='b') self.l5c = id_block([512, 512, 2048], stage=5, block='c') - self.avg_pool = tf.layers.AveragePooling2D( + self.avg_pool = layers.AveragePooling2D( (7, 7), strides=(7, 7), data_format=data_format) if self.include_top: - self.fc1000 = tf.layers.Dense(classes, name='fc1000') + self.flatten = layers.Flatten() + self.fc1000 = layers.Dense(classes, name='fc1000') else: reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3] reduction_indices = tf.constant(reduction_indices) @@ -298,7 +301,7 @@ class ResNet50(tf.keras.Model): x = self.avg_pool(x) if self.include_top: - return self.fc1000(tf.layers.flatten(x)) + return self.fc1000(self.flatten(x)) elif self.global_pooling: return self.global_pooling(x) else: diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py 
b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 88fffc962f..492adbe1d8 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -73,6 +73,8 @@ try: except ImportError: HAS_MATPLOTLIB = False +layers = tf.keras.layers + def parse(line): """Parse a line from the colors dataset.""" @@ -152,7 +154,7 @@ class RNNColorbot(tf.keras.Model): self.cells = self._add_cells( [tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes]) - self.relu = tf.layers.Dense( + self.relu = layers.Dense( label_dimension, activation=tf.nn.relu, name="relu") def call(self, inputs, training=False): @@ -204,7 +206,7 @@ class RNNColorbot(tf.keras.Model): def _add_cells(self, cells): # "Magic" required for keras.Model classes to track all the variables in - # a list of tf.layers.Layer objects. + # a list of layers.Layer objects. # TODO(ashankar): Figure out API so user code doesn't have to do this. for i, c in enumerate(cells): setattr(self, "cell-%d" % i, c) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 69cd16d12c..a90048d813 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -38,6 +38,8 @@ import tensorflow as tf from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn from tensorflow.contrib.eager.python import tfe +layers = tf.keras.layers + class RNN(tf.keras.Model): """A static RNN. @@ -74,14 +76,14 @@ class RNN(tf.keras.Model): def _add_cells(self, cells): # "Magic" required for keras.Model classes to track all the variables in - # a list of tf.layers.Layer objects. + # a list of Layer objects. # TODO(ashankar): Figure out API so user code doesn't have to do this. 
for i, c in enumerate(cells): setattr(self, "cell-%d" % i, c) return cells -class Embedding(tf.layers.Layer): +class Embedding(layers.Layer): """An Embedding layer.""" def __init__(self, vocab_size, embedding_dim, **kwargs): @@ -132,7 +134,7 @@ class PTBModel(tf.keras.Model): else: self.rnn = RNN(hidden_dim, num_layers, self.keep_ratio) - self.linear = tf.layers.Dense( + self.linear = layers.Dense( vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1)) self._output_shape = [-1, embedding_dim] diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 3f9a7818a5..6673653418 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -33,6 +33,7 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe from tensorflow.contrib.eager.python.examples.spinn import data from third_party.examples.eager.spinn import spinn +from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util @@ -172,7 +173,7 @@ class SpinnTest(test_util.TensorFlowTestCase): right_in.append(tf.random_normal((1, size * 2))) tracking.append(tf.random_normal((1, tracker_size * 2))) - out = reducer(left_in, right_in, tracking=tracking) + out = reducer(left_in, right_in=right_in, tracking=tracking) self.assertEqual(batch_size, len(out)) self.assertEqual(tf.float32, out[0].dtype) self.assertEqual((1, size * 2), out[0].shape) @@ -226,7 +227,7 @@ class SpinnTest(test_util.TensorFlowTestCase): self.assertEqual((batch_size, size * 2), stacks[0][0].shape) for _ in range(2): - out1, out2 = tracker(bufs, stacks) + out1, out2 = tracker(bufs, stacks=stacks) self.assertIsNone(out2) self.assertEqual(batch_size, len(out1)) self.assertEqual(tf.float32, out1[0].dtype) @@ 
-259,7 +260,7 @@ class SpinnTest(test_util.TensorFlowTestCase): self.assertEqual(tf.int64, transitions.dtype) self.assertEqual((num_transitions, 1), transitions.shape) - out = s(buffers, transitions, training=True) + out = s(buffers, transitions=transitions, training=True) self.assertEqual(tf.float32, out.dtype) self.assertEqual((1, embedding_dims), out.shape) @@ -285,12 +286,15 @@ class SpinnTest(test_util.TensorFlowTestCase): vocab_size) # Invoke model under non-training mode. - logits = model(prem, prem_trans, hypo, hypo_trans, training=False) + logits = model( + prem, premise_transition=prem_trans, hypothesis=hypo, + hypothesis_transition=hypo_trans, training=False) self.assertEqual(tf.float32, logits.dtype) self.assertEqual((batch_size, d_out), logits.shape) # Invoke model under training model. - logits = model(prem, prem_trans, hypo, hypo_trans, training=True) + logits = model(prem, premise_transition=prem_trans, hypothesis=hypo, + hypothesis_transition=hypo_trans, training=True) self.assertEqual(tf.float32, logits.dtype) self.assertEqual((batch_size, d_out), logits.shape) @@ -420,8 +424,14 @@ class SpinnTest(test_util.TensorFlowTestCase): # 5. Verify that checkpoints exist and contains all the expected variables. 
self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) - ckpt_variable_names = [ - item[0] for item in checkpoint_utils.list_variables(config.logdir)] + object_graph_string = checkpoint_utils.load_variable( + config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH") + object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph() + object_graph.ParseFromString(object_graph_string) + ckpt_variable_names = set() + for node in object_graph.nodes: + for attribute in node.attributes: + ckpt_variable_names.add(attribute.full_name) self.assertIn("global_step", ckpt_variable_names) for v in trainer.variables: variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 24374266dc..c846343d6d 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -358,7 +358,7 @@ cuda_py_test( size = "medium", srcs = ["python/estimator/replicate_model_fn_test.py"], additional_deps = [ - "//third_party/py/absl/testing:parameterized", + "@absl_py//absl/testing:parameterized", "//tensorflow/python/estimator", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:export_export", diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 42e1b7b68c..74da2cbb3f 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -304,7 +304,7 @@ def multi_label_head(n_classes, weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, name=None): """Creates a `_Head` for multi-label classification. @@ -355,7 +355,8 @@ def multi_label_head(n_classes, string type and have any value in `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. 
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely + weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -404,7 +405,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, name=None): self._n_classes = n_classes diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 776f0ee341..8837dfdc6c 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -272,9 +272,9 @@ class MultiLabelHead(test.TestCase): logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) - # loss = labels * -log(sigmoid(logits)) + - # (1 - labels) * -log(1 - sigmoid(logits)) - expected_training_loss = np.sum( + # loss = (labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits))) / 2 + expected_training_loss = 0.5 * np.sum( _sigmoid_cross_entropy(labels=labels, logits=logits)) actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -298,7 +298,7 @@ class MultiLabelHead(test.TestCase): # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_training_loss = np.sum( + expected_training_loss = 0.5 * np.sum( np.array([[(10. + 10.) / 2.], [(15. + 0.) 
/ 2.]], dtype=np.float32)) actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -361,7 +361,7 @@ class MultiLabelHead(test.TestCase): labels=labels_input)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllClose(np.sum(loss), actual_training_loss.eval()) + self.assertAllClose(np.sum(loss) / 2., actual_training_loss.eval()) def test_eval_create_loss_loss_fn_wrong_shape(self): """Tests custom loss_fn that returns Tensor of unexpected shape.""" @@ -438,12 +438,13 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, @@ -468,14 +469,13 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. 
- keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, @@ -533,14 +533,13 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, @@ -562,15 +561,14 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, @@ -603,8 +601,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, weighted sum over examples. - expected_loss = 25. 
+ # Average over classes, weighted sum over examples, divide by batch_size. + # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2) / 2 + expected_loss = 12.5 spec = head.create_estimator_spec( features={ @@ -617,8 +616,8 @@ class MultiLabelHead(test.TestCase): keys = metric_keys.MetricKeys expected_metrics = { - # Average loss over weighted examples. - keys.LOSS_MEAN: expected_loss / 3, + # Average loss over weighted examples (denominator is sum(weights)). + keys.LOSS_MEAN: expected_loss * (2. / 3.), # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.2000, @@ -663,7 +662,7 @@ class MultiLabelHead(test.TestCase): # (1 - labels) * (logits > 0) * logits expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]] expected_weights = [[1.], [2.]] - expected_training_loss = 1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2. + expected_training_loss = (1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.) / 2. training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={ 'x': np.array(((42,),), dtype=np.int32), @@ -809,11 +808,8 @@ class MultiLabelHead(test.TestCase): self.assertEqual( six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - _assert_simple_summaries(self, { - metric_keys.MetricKeys.LOSS: expected_loss, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, - }, summary_str, tol) + _assert_simple_summaries( + self, {metric_keys.MetricKeys.LOSS: expected_loss}, summary_str, tol) def test_train(self): head = head_lib.multi_label_head(n_classes=2) @@ -823,8 +819,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. 
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) @@ -840,8 +837,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) @@ -858,8 +856,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) @@ -871,8 +870,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 expected_train_result = 'my_train_op' class _Optimizer(object): @@ -952,8 +952,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, weighted sum over examples. - expected_loss = 25. + # Average over classes, weighted sum over examples, divide by batch_size. 
+ # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2 ) / 2 + expected_loss = 12.5 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -987,11 +988,8 @@ class MultiLabelHead(test.TestCase): self.assertEqual( six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - _assert_simple_summaries(self, { - metric_keys.MetricKeys.LOSS: expected_loss, - # Average loss over weighted examples. - metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3, - }, summary_str, tol) + _assert_simple_summaries( + self, {metric_keys.MetricKeys.LOSS: expected_loss,}, summary_str, tol) def test_multi_dim_weighted_train_create_loss(self): """Logits and labels of shape [2, 2, 3], weights [2, 2].""" @@ -1008,8 +1006,8 @@ class MultiLabelHead(test.TestCase): expected_unreduced_loss = [[[20./3.], [10./3.]], [[4.], [8.]]] # weights are reshaped to [2, 2, 1] to match logits. expected_weights = [[[1.], [1.5]], [[2.], [2.5]]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_training_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_training_loss = 9.9167 training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={'weights': weights}, mode=model_fn.ModeKeys.TRAIN, @@ -1035,8 +1033,8 @@ class MultiLabelHead(test.TestCase): weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 # = [[20/3, 10/3], [4, 8]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_loss = 9.9167 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -1124,11 +1122,11 @@ class MultiLabelHead(test.TestCase): weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 # = [[20/3, 10/3], [4, 8]] - # 
weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_loss = 9.9167 keys = metric_keys.MetricKeys expected_metrics = { - keys.LOSS_MEAN: expected_loss / np.sum(weights), + keys.LOSS_MEAN: expected_loss * (4. / np.sum(weights)), # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.4977, diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 43cc157a1f..74d3d6d728 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -299,10 +299,11 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] - # Average over classes, weighted sum over batch and heads. - expected_loss_head1 = 17.5 - expected_loss_head2 = 30.0 + # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15 + expected_loss_head1 = 8.75 + expected_loss_head2 = 15. expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 spec = multi_head.create_estimator_spec( @@ -316,8 +317,8 @@ class MultiHeadTest(test.TestCase): keys.LOSS + '/head1': expected_loss_head1, keys.LOSS + '/head2': expected_loss_head2, # Average loss over examples. - keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, - keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, + keys.LOSS_MEAN + '/head1': expected_loss_head1, + keys.LOSS_MEAN + '/head2': expected_loss_head2, # auc and auc_pr cannot be reliably calculated for only 4-6 samples, but # this assert tests that the algorithm remains consistent. 
keys.AUC + '/head1': 0.1667, @@ -363,8 +364,8 @@ class MultiHeadTest(test.TestCase): tol = 1e-3 with self.test_session(): # Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2] - # (averaged over classes, sum-reduced over examples). - self.assertAllClose(17.5, loss.eval(), rtol=tol, atol=tol) + # (averaged over classes, averaged over examples). + self.assertAllClose(8.75, loss.eval(), rtol=tol, atol=tol) def test_train_create_loss_two_heads_with_weights(self): # Use different example weighting for each head weighting. @@ -399,18 +400,18 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # training_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5 # head-weighted unreduced_loss = 1 * [10, 7.5] self.assertAllClose( [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # training_loss = 2 * 20 + 3 * 10 = 70 + # training_loss = (2 * 20 + 3 * 10) / 2 = 35 # head-weighted unreduced_loss = 2 * [20, 10] self.assertAllClose( [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) - # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5 + self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol) # head-weighted example weights self.assertAllClose( [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) @@ -447,18 +448,18 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # training_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5 # head-weighted unreduced_loss = 1 * [10, 7.5] self.assertAllClose( [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of 
the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # training_loss = 2 * 20 + 3 * 10 = 70 + # training_loss = (2 * 20 + 3 * 10) / 2 = 35 # head-weighted unreduced_loss = 2 * [20, 10] self.assertAllClose( [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) - # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5 + self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol) # head-weighted example weights self.assertAllClose( [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) @@ -511,8 +512,8 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 + expected_loss = 8.75 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -546,8 +547,6 @@ class MultiHeadTest(test.TestCase): _assert_simple_summaries(self, { metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2, }, summary_str, tol) def test_train_one_head_with_optimizer(self): @@ -560,8 +559,8 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. 
- expected_loss = 17.5 + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 + expected_loss = 8.75 expected_train_result = 'my_train_op' class _Optimizer(object): @@ -607,10 +606,12 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] + # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15 # Average over classes, weighted sum over batch and heads. - expected_loss_head1 = 17.5 - expected_loss_head2 = 30.0 + expected_loss_head1 = 8.75 + expected_loss_head2 = 15.0 expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 expected_train_result = 'my_train_op' def _train_op_fn(loss): @@ -646,9 +647,6 @@ class MultiHeadTest(test.TestCase): metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1, metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2, - # Average loss over examples. 
- metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, - metric_keys.MetricKeys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, }, summary_str, tol) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index 30c5404e03..f22dbcf215 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.contrib.kfac.python.ops import estimator from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -40,30 +39,6 @@ from tensorflow.python.training import training_util _ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] -class DeviceContextGeneratorTest(test.TestCase): - - def testNoDevice(self): - device_context_generator = estimator._DeviceContextGenerator(None) - with ops.device("/device:CPU:0"): # This is what will be used - with device_context_generator(): # Does nothing - a = constant_op.constant([2.0], name="a") - self.assertEqual("/device:CPU:0", a.op.device) - - def testTwoDevices(self): - device_context_generator = estimator._DeviceContextGenerator( - ["/device:GPU:0", "/device:GPU:1"]) - with ops.device("/device:CPU:0"): # Will be over-ridden by the inner scopes - with device_context_generator(): - a = constant_op.constant([2.0], name="a") - with device_context_generator(): - b = constant_op.constant([2.0], name="b") - with device_context_generator(): - c = constant_op.constant([2.0], name="c") - self.assertEqual("/device:GPU:0", a.op.device) - self.assertEqual("/device:GPU:1", b.op.device) - self.assertEqual("/device:GPU:0", c.op.device) - - class 
EstimatorTest(test.TestCase): def setUp(self): @@ -90,68 +65,98 @@ class EstimatorTest(test.TestCase): def testEstimatorInitManualRegistration(self): with self._graph.as_default(): # We should be able to build an estimator for only the registered vars. - estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection) + estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) # Check that we throw an error if we try to build an estimator for vars # that were not manually registered. with self.assertRaises(ValueError): - est = estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2, - self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights, self.bias], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) est.make_ops_and_vars() # Check that we throw an error if we don't include registered variables, # i.e. self.weights with self.assertRaises(ValueError): - est = estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) est.make_ops_and_vars() @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) def testVariableWrongNumberOfUses(self, mock_uses): with self.assertRaises(ValueError): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) est.make_ops_and_vars() def testInvalidEstimationMode(self): with self.assertRaises(ValueError): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="not_a_real_mode") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, 
+ layer_collection=self.layer_collection, + estimation_mode="not_a_real_mode") est.make_ops_and_vars() def testGradientsModeBuild(self): with self._graph.as_default(): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="gradients") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="gradients") est.make_ops_and_vars() def testEmpiricalModeBuild(self): with self._graph.as_default(): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="empirical") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="empirical") est.make_ops_and_vars() def testCurvaturePropModeBuild(self): with self._graph.as_default(): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="curvature_prop") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="curvature_prop") est.make_ops_and_vars() def testExactModeBuild(self): with self._graph.as_default(): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, - estimation_mode="exact") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="exact") est.make_ops_and_vars() def test_cov_update_thunks(self): """Ensures covariance update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: - fisher_estimator = estimator.FisherEstimator( + fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, damping=0.2, @@ 
-159,8 +164,8 @@ class EstimatorTest(test.TestCase): # Construct an op that executes one covariance update per step. global_step = training_util.get_or_create_global_step() - (cov_variable_thunks, cov_update_op_thunks, - _, _) = fisher_estimator.create_ops_and_vars_thunks() + (cov_variable_thunks, cov_update_op_thunks, _, + _) = fisher_estimator.create_ops_and_vars_thunks() for thunk in cov_variable_thunks: thunk() cov_matrices = [ @@ -198,10 +203,43 @@ class EstimatorTest(test.TestCase): sess.run(cov_update_op) sess.run(increment_global_step) + def test_round_robin_placement(self): + """Check if the ops and variables are placed on devices correctly.""" + with self._graph.as_default(): + fisher_estimator = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + layer_collection=self.layer_collection, + damping=0.2, + cov_ema_decay=0.0, + cov_devices=["/cpu:{}".format(i) for i in range(2)], + inv_devices=["/cpu:{}".format(i) for i in range(2)]) + + # Construct an op that executes one covariance update per step. + (cov_update_ops, _, inv_update_ops, _, _, + _) = fisher_estimator.make_ops_and_vars(scope="test") + self.assertEqual(cov_update_ops[0].device, "/device:CPU:0") + self.assertEqual(cov_update_ops[1].device, "/device:CPU:1") + self.assertEqual(inv_update_ops[0].device, "/device:CPU:0") + self.assertEqual(inv_update_ops[1].device, "/device:CPU:1") + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + inv_matrices = [ + matrix + for fisher_factor in self.layer_collection.get_factors() + for matrix in fisher_factor._matpower_by_exp_and_damping.values() + ] + self.assertEqual(cov_matrices[0].device, "/device:CPU:0") + self.assertEqual(cov_matrices[1].device, "/device:CPU:1") + # Inverse matrices need to be explicitly placed. 
+ self.assertEqual(inv_matrices[0].device, "") + self.assertEqual(inv_matrices[1].device, "") + def test_inv_update_thunks(self): """Ensures inverse update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: - fisher_estimator = estimator.FisherEstimator( + fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, damping=0.2, diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index c26230c2a8..d721ad08af 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -171,6 +171,7 @@ py_library( name = "fisher_estimator", srcs = [ "estimator.py", + "placement.py", ], srcs_version = "PY2AND3", deps = [ @@ -180,6 +181,7 @@ py_library( "//tensorflow/python:gradients", "//tensorflow/python:util", "//third_party/py/numpy", + "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index 64755be65c..ced1110676 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -18,11 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -import itertools - +import abc import numpy as np +import six +from tensorflow.contrib.kfac.python.ops import placement from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import control_flow_ops @@ -31,63 +31,46 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest -class _DeviceContextGenerator(object): - """Class for generating device contexts in a round-robin fashion.""" - - def __init__(self, devices): - """Creates a _DeviceContextGenerator object. - - Example usage: +# The linter is confused. 
+# pylint: disable=abstract-class-instantiated +def make_fisher_estimator(placement_strategy=None, **kwargs): + """Creates Fisher estimator instances based on the placement strategy. - ```python - dcg = _DeviceContextGenerator(['/gpu:0', 'gpu:1']) - with dcg(): - # All operations in this context will be placed on GPU 0 - ... - with dcg(): - # All operations in this context will be placed on GPU 1 - ... - ``` - - Args: - devices: An iterable of device strings (or None). Successive calls to - __call__ will give contexts which place devices on these devices in - a round-robin fashion. - """ - self._cycle = None if devices is None else itertools.cycle(devices) + For example if the `placement_strategy` is 'round_robin' then + `FisherEstimatorRoundRobin` instance is returned. - @contextlib.contextmanager - def __call__(self): - """Returns a context manager specifying the default device.""" - if self._cycle is None: - yield - else: - with tf_ops.device(next(self._cycle)): - yield + Args: + placement_strategy: `string`, Strategy to be used for placing covariance + variables, covariance ops and inverse ops. Check + `placement.FisherEstimatorRoundRobin` for a concrete example. + **kwargs: Arguments to be passed into `FisherEstimator` class initializer. + Returns: + An instance of class which inherits from `FisherEstimator` and the mixin + which implements specific placement strategy. See, + `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and + `RoundRobinPlacementMixin`. -def _make_thunk_on_device(func, device): - def thunk(): - with tf_ops.device(device): - return func() - return thunk + Raises: + ValueError: If the `placement_strategy` is not equal to 'round_robin'. 
+ """ + if placement_strategy in [None, "round_robin"]: + return FisherEstimatorRoundRobin(**kwargs) + else: + raise ValueError("Unimplemented vars and ops placement strategy : %s", + placement_strategy) +# pylint: enable=abstract-class-instantiated +@six.add_metaclass(abc.ABCMeta) class FisherEstimator(object): """Fisher estimator class supporting various approximations of the Fisher. - Attributes: - cov_update_thunks: list of no-arg functions. Executing a function adds - covariance update ops for a single FisherFactor to the graph. - cov_update_ops: List of Ops. Running an op updates covariance matrices for a - single FisherFactor. - cov_update_op: Op. Running updates covariance matrices for all - FisherFactors. - inv_update_thunks: list of no-arg functions. Executing a function adds - inverse update ops for a single FisherFactor to the graph. - inv_update_ops: List of Ops. Running an op updates inverse matrices for a - single FisherFactor. - inv_update_op: Op. Running updates inverse matrices for all FisherFactors. + This is an abstract base class which does not implement a strategy for + placing covariance variables, covariance update ops and inverse update ops. + The placement strategies are implemented in `placement.py`. See + `FisherEstimatorRoundRobin` for example of a concrete subclass with + a round-robin placement strategy. """ def __init__(self, @@ -184,6 +167,77 @@ class FisherEstimator(object): def name(self): return self._name + @abc.abstractmethod + def make_ops_and_vars(self, scope=None): + """Make ops and vars with a specific placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. For example in case of + round robin placement a new device is chosen for each factor by cycling + through list of devices in the cov_devices argument. If cov_devices is None + then no explicit device placement occurs. 
+ + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the inv_devices argument. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all ops will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_ops: List of ops that compute the cov updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + cov_update_op: cov_update_ops grouped into a single op. + inv_update_ops: List of ops that compute the inv updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + inv_update_op: inv_update_ops grouped into a single op. + cov_update_thunks: Thunks that make the ops in cov_update_ops. + inv_update_thunks: Thunks that make the ops in inv_update_ops. + """ + pass + + @abc.abstractmethod + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks with a specific placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the cov_devices + argument. If cov_devices is None then no explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the inv_devices argument. 
+ + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + pass + def _apply_transformation(self, vecs_and_vars, transform): """Applies an block-wise transformation to the corresponding vectors. @@ -286,158 +340,6 @@ class FisherEstimator(object): self._instantiate_factors() self._register_matrix_functions() - def make_ops_and_vars(self, scope=None): - """Make ops and vars with no specific device placement. - - See make_ops_and_vars_round_robin for further details. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All variables will be created, - and all ops will execute, inside of a variable scope of the given - name. (Default: None) - Returns: - cov_update_ops: List of ops that compute the cov updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_ops: List of ops that compute the inv updates. Corresponds - one-to-one with the list of factors given by the "factors" property. 
- inv_update_op: inv_update_ops grouped into a single op. - cov_update_thunks: Thunks that make the ops in cov_update_ops. - inv_update_thunks: Thunks that make the ops in inv_update_ops. - """ - return self.make_ops_and_vars_round_robin(scope=scope) - - # TODO(b/70674513): Factor device placement outside of this class. - def make_ops_and_vars_round_robin(self, scope=None, cov_devices=None, - inv_devices=None): - """Make ops and vars with a round-robin device placement strategy. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through list of devices in the cov_devices - argument. If cov_devices is None then no explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the inv_devices argument. - - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All variables will be created, - and all ops will execute, inside of a variable scope of the given - name. (Default: None) - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. 
- - Returns: - cov_update_ops: List of ops that compute the cov updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_ops: List of ops that compute the inv updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - inv_update_op: inv_update_ops grouped into a single op. - cov_update_thunks: Thunks that make the ops in cov_update_ops. - inv_update_thunks: Thunks that make the ops in inv_update_ops. - """ - (cov_update_thunks, - inv_update_thunks) = self.make_vars_and_create_op_thunks_round_robin( - scope=scope, - cov_devices=cov_devices, - inv_devices=inv_devices) - cov_update_ops = [thunk() for thunk in cov_update_thunks] - inv_update_ops = [thunk() for thunk in inv_update_thunks] - - scope = self.name if scope is None else scope - with variable_scope.variable_scope(scope): - cov_update_op = control_flow_ops.group(cov_update_ops, - name="cov_update_op") - inv_update_op = control_flow_ops.group(inv_update_ops, - name="inv_update_op") - - return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op, - cov_update_thunks, inv_update_thunks) - - def make_vars_and_create_op_thunks_round_robin(self, - scope=None, - cov_devices=None, - inv_devices=None): - """Make vars and create op thunks w/ a round-robin device placement strat. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through list of devices in the cov_devices - argument. If cov_devices is None then no explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the inv_devices argument. - - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). 
The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All variables will be created, - and all thunks will execute, inside of a variable scope of the given - name. (Default: None) - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - Returns: - cov_update_thunks: List of cov update thunks. Corresponds one-to-one with - the list of factors given by the "factors" property. - inv_update_thunks: List of inv update thunks. Corresponds one-to-one with - the list of factors given by the "factors" property. 
- """ - - (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, - inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope) - - if cov_devices: - cov_update_thunks = [] - for cov_variable_thunk, cov_update_thunk, device in zip( - cov_variable_thunks_raw, cov_update_thunks_raw, - itertools.cycle(cov_devices)): - with tf_ops.device(device): - cov_variable_thunk() - cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk, - device)) - else: - for cov_variable_thunk in cov_variable_thunks_raw: - cov_variable_thunk() - cov_update_thunks = cov_update_thunks_raw - - for inv_variable_thunk in inv_variable_thunks_raw: - inv_variable_thunk() - - if inv_devices: - inv_update_thunks = [] - for inv_update_thunk, device in zip(inv_update_thunks_raw, - itertools.cycle(inv_devices)): - inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk, - device)) - else: - inv_update_thunks = inv_update_thunks_raw - - return cov_update_thunks, inv_update_thunks - def create_ops_and_vars_thunks(self, scope=None): """Create thunks that make the ops and vars on demand. 
@@ -582,3 +484,9 @@ class FisherEstimator(object): colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) return zip(*grads_all) + + +class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin, + FisherEstimator): + """Fisher estimator which provides round robin device placement strategy.""" + pass diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 083da768ec..843aeef7d8 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import warnings - # pylint disable=long-line from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp from tensorflow.contrib.kfac.python.ops import estimator as est @@ -53,8 +52,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): estimation_mode="gradients", colocate_gradients_with_ops=True, batch_size=None, - cov_devices=None, - inv_devices=None): + placement_strategy=None, + **kwargs): """Initializes the KFAC optimizer with the given settings. Args: @@ -96,14 +95,11 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): (Default: True) batch_size: The size of the mini-batch. Only needed when momentum_type == 'qmodel' or when automatic adjustment is used. (Default: None) - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. Only used - with (soon-to-be-depcrecated "convenience" properties). - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. Only used - with (soon-to-be-depcrecated "convenience" properties). 
+ placement_strategy: string, Device placement strategy used when creating + covariance variables, covariance ops, and inverse ops. + (Default: `None`) + **kwargs: Arguments to be passesd to specific placement + strategy mixin. Check `placement.RoundRobinPlacementMixin` for example. Raises: ValueError: If the momentum type is unsupported. @@ -123,8 +119,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._layers = layer_collection self._estimation_mode = estimation_mode self._colocate_gradients_with_ops = colocate_gradients_with_ops - self._cov_devices = cov_devices - self._inv_devices = inv_devices # The below paramaters are required only if damping needs to be adapated. # These parameters can be set by calling @@ -164,16 +158,19 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._momentum_type = momentum_type self._norm_constraint = norm_constraint self._batch_size = batch_size + self._placement_strategy = placement_strategy with variable_scope.variable_scope(name): - self._fisher_est = est.FisherEstimator( - self._variables, - self._cov_ema_decay, - self.damping, - self._layers, + self._fisher_est = est.make_fisher_estimator( + placement_strategy=placement_strategy, + variables=self._variables, + cov_ema_decay=self._cov_ema_decay, + damping=self.damping, + layer_collection=self._layers, exps=(-1,), estimation_mode=self._estimation_mode, - colocate_gradients_with_ops=self._colocate_gradients_with_ops) + colocate_gradients_with_ops=self._colocate_gradients_with_ops, + **kwargs) super(KfacOptimizer, self).__init__(learning_rate, name=name) @@ -237,6 +234,21 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): "damping", initializer=self._damping_constant, trainable=False) @property + def variables(self): + return self._variables + + @property + def damping(self): + if self._damping: + return self._damping + else: + return self._damping_constant + + @property + def damping_adaptation_interval(self): + return 
self._damping_adaptation_interval + + @property def cov_update_thunks(self): self._maybe_make_and_save_everything() return self._cov_update_thunks @@ -266,37 +278,20 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._maybe_make_and_save_everything() return self._inv_update_op - @property - def variables(self): - return self._variables - - @property - def damping(self): - if self._damping: - return self._damping - else: - return self._damping_constant - - @property - def damping_adaptation_interval(self): - return self._damping_adaptation_interval - def _maybe_make_and_save_everything(self): if not self._fisher_est.made_vars(): warnings.warn("These convenience properties will be depcrecated soon. " "Please use explicit op/thunk creation methods instead " - "(e.g. make_ops_and_vars_round_robin, etc).", + "(e.g. make_ops_and_vars, etc).", DeprecationWarning) (self._cov_update_ops, self._cov_update_op, self._inv_update_ops, self._inv_update_op, self._cov_update_thunks, - self._inv_update_thunks) = self.make_ops_and_vars_round_robin( - cov_devices=self._cov_devices, - inv_devices=self._inv_devices) + self._inv_update_thunks) = self.make_ops_and_vars() def make_ops_and_vars(self): - """Make ops and vars with no specific device placement. + """Make ops and vars with device placement `self._placement_strategy`. - See make_ops_and_vars_round_robin for details. + See `FisherEstimator.make_ops_and_vars` for details. Returns: cov_update_ops: List of ops that compute the cov updates. Corresponds @@ -307,88 +302,21 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): cov_update_op: cov_update_ops grouped into a single op. inv_update_op: inv_update_ops grouped into a single op. """ - with variable_scope.variable_scope(self.get_name()): - return self._fisher_est.make_ops_and_vars() - - def make_ops_and_vars_round_robin(self, cov_devices=None, inv_devices=None): - """Make ops and vars with a round-robin device placement strategy. 
- - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through list of devices in the cov_devices - argument. If cov_devices is None then no explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the inv_devices argument. + return self._fisher_est.make_ops_and_vars(scope=self.get_name()) - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. + def make_vars_and_create_op_thunks(self): + """Make vars and create op thunks. Returns: - cov_update_ops: List of ops that compute the cov updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_ops: List of ops that compute the inv updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_op: inv_update_ops grouped into a single op. - cov_update_thunks: Thunks that make the ops in cov_update_ops. - inv_update_thunks: Thunks that make the ops in inv_update_ops. 
- """ - with variable_scope.variable_scope(self.get_name()): - return self._fisher_est.make_ops_and_vars_round_robin( - cov_devices=cov_devices, inv_devices=inv_devices) - - def make_vars_and_create_op_thunks_round_robin(self, - cov_devices=None, - inv_devices=None): - """Make vars and create op thunks w/ a round-robin device placement strat. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through list of devices in the cov_devices - argument. If cov_devices is None then no explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the inv_devices argument. - - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - Returns: cov_update_thunks: List of cov update thunks. Corresponds one-to-one with the list of factors given by the "factors" property. inv_update_thunks: List of inv update thunks. Corresponds one-to-one with the list of factors given by the "factors" property. 
""" scope = self.get_name() + "/" + self._fisher_est.name - return self._fisher_est.make_vars_and_create_op_thunks_round_robin( - scope=scope, cov_devices=cov_devices, inv_devices=inv_devices) + return self._fisher_est.make_vars_and_create_op_thunks(scope=scope) - def ops_and_vars_thunks(self): + def create_ops_and_vars_thunks(self): """Create thunks that make the ops and vars on demand. This function returns 4 lists of thunks: cov_variable_thunks, @@ -413,7 +341,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): inv_update_thunks: A list of thunks that make the inv update ops. """ scope = self.get_name() + "/" + self._fisher_est.name - return self._fisher_est.ops_and_vars_thunks(scope=scope) + return self._fisher_est.create_ops_and_vars_thunks(scope=scope) def minimize(self, *args, **kwargs): # Should this variable scope encompass everything below? Or will the super- @@ -462,7 +390,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): An `Operation` that applies the specified gradients. """ self._maybe_make_and_save_everything() - # In Python 3, grads_and_vars can be a zip() object which can only be # iterated over once. By converting it to a list, we ensure that it can be # iterated over more than once. @@ -618,7 +545,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): # compute the matrix-vector products with the transposed Fisher factor fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads) fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates) - batch_size = math_ops.cast( self._batch_size, dtype=fft_precon_grads[0].dtype) @@ -802,7 +728,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): # Go through variable and update its associated part of the velocity vector. return [_update_velocity(vec, var) for vec, var in vecs_and_vars] - # TODO(b/73448937): Move all update damping code to a separate class/function. 
def _update_damping(self, prev_batch, global_step): """Adapts damping parameter. Check KFAC (Section 6.5) for the details. diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py new file mode 100644 index 0000000000..bf12dbaa9a --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/placement.py @@ -0,0 +1,167 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements placement strategies for cov and inv ops, cov variables.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variable_scope + + +def _make_thunk_on_device(func, device): + def thunk(): + with tf_ops.device(device): + return func() + return thunk + + +class RoundRobinPlacementMixin(object): + """Implements round robin placement strategy for ops and variables.""" + + def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs): + """Initializes the RoundRobinPlacementMixin class. + + Args: + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. 
+ Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + *args: + **kwargs: + + """ + super(RoundRobinPlacementMixin, self).__init__(*args, **kwargs) + self._cov_devices = cov_devices + self._inv_devices = inv_devices + + def make_ops_and_vars(self, scope=None): + """Make ops and vars with a round-robin device placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the + `self._cov_devices` attribute. If `self._cov_devices` is `None` then no + explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the `self._inv_devices` attribute. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all ops will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_ops: List of ops that compute the cov updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + cov_update_op: cov_update_ops grouped into a single op. + inv_update_ops: List of ops that compute the inv updates. 
Corresponds + one-to-one with the list of factors given by the "factors" property. + inv_update_op: inv_update_ops grouped into a single op. + cov_update_thunks: Thunks that make the ops in cov_update_ops. + inv_update_thunks: Thunks that make the ops in inv_update_ops. + """ + (cov_update_thunks, + inv_update_thunks) = self.make_vars_and_create_op_thunks(scope=scope) + cov_update_ops = [thunk() for thunk in cov_update_thunks] + inv_update_ops = [thunk() for thunk in inv_update_thunks] + + scope = self.name if scope is None else scope + with variable_scope.variable_scope(scope): + cov_update_op = control_flow_ops.group(cov_update_ops, + name="cov_update_op") + inv_update_op = control_flow_ops.group(inv_update_ops, + name="inv_update_op") + + return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op, + cov_update_thunks, inv_update_thunks) + + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks w/ a round-robin device placement strat. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the + `self._cov_devices` attribute. If `self._cov_devices` is `Non`e then no + explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the `self._inv_devices` attribute. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). 
All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`. + (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, + inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope) + + if self._cov_devices: + cov_update_thunks = [] + for cov_variable_thunk, cov_update_thunk, device in zip( + cov_variable_thunks_raw, cov_update_thunks_raw, + itertools.cycle(self._cov_devices)): + with tf_ops.device(device): + cov_variable_thunk() + cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk, + device)) + else: + for cov_variable_thunk in cov_variable_thunks_raw: + cov_variable_thunk() + cov_update_thunks = cov_update_thunks_raw + + for inv_variable_thunk in inv_variable_thunks_raw: + inv_variable_thunk() + + if self._inv_devices: + inv_update_thunks = [] + for inv_update_thunk, device in zip(inv_update_thunks_raw, + itertools.cycle(self._inv_devices)): + inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk, + device)) + else: + inv_update_thunks = inv_update_thunks_raw + + return cov_update_thunks, inv_update_thunks diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java new file mode 100644 index 0000000000..d0102883e6 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java @@ -0,0 +1,197 @@ +/*Copyright 2018 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in 
compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.ovic; + +import android.graphics.Bitmap; +import android.os.SystemClock; +import android.util.Log; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; + +/** + * Class that benchmarks image classifier models. + * + * <p>===================== General workflow ======================= + * + * <pre>{@code + * benchmarker = new OvicBenchmarker(); + * benchmarker.getReadyToTest(labelInputStream, model); + * while (!benchmarker.shouldStop()) { + * Bitmap bitmap = ... + * benchmarker.doTestIteration(bitmap); + * } + * }</pre> + */ +public class OvicBenchmarker { + /** Tag for the {@link Log}. */ + private static final String TAG = "OvicBenchmarker"; + + /** Evaluation transformation parameters. */ + private static final float CENTRAL_FRACTION = 0.875f; + + /** Dimensions of inputs. */ + private static final int DIM_BATCH_SIZE = 1; + private static final int DIM_PIXEL_SIZE = 3; + private int imgHeight = 224; + private int imgWidth = 224; + + /* Preallocated buffers for storing image data in. */ + private int[] intValues = null; + + /** A ByteBuffer to hold image data, to be feed into classifier as inputs. */ + private ByteBuffer imgData = null; + + private OvicClassifier classifier; + + /** Total runtime in ms. */ + private double totalRuntime = 0.0; + /** Total allowed runtime in ms. 
*/ + private double wallTime = 20000 * 30.0; + + private Boolean benchmarkStarted = null; + + /** + * Initializes an {@link OvicBenchmarker} + * + * @param wallTime: a double number specifying the total amount of time to benchmark. + */ + public OvicBenchmarker(double wallTime) { + benchmarkStarted = false; + totalRuntime = 0.0; + this.wallTime = wallTime; + } + + /** Check whether the benchmarker should stop. */ + public Boolean shouldStop() { + if (totalRuntime >= wallTime) { + Log.e( + TAG, + "Total runtime " + + Double.toString(totalRuntime) + + " exceeded walltime " + + Double.toString(wallTime)); + return true; + } + return false; + } + + /** Check whether the benchmarker is ready to start classifying images. */ + public Boolean readyToTest() { + return (classifier != null); + } + + /** + * Getting the benchmarker ready for classifying images. + * + * @param labelInputStream: an {@link InputStream} specifying where the list of labels should be + * read from. + * @param model: a {@link MappedByteBuffer} model to benchmark. + */ + public void getReadyToTest(InputStream labelInputStream, MappedByteBuffer model) { + try { + Log.i(TAG, "Creating classifier."); + classifier = new OvicClassifier(labelInputStream, model); + int [] inputDims = classifier.getInputDims(); + imgHeight = inputDims[1]; + imgWidth = inputDims[2]; + // Only accept QUANTIZED_UINT8 input. + imgData = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * imgHeight * imgWidth * DIM_PIXEL_SIZE); + imgData.order(ByteOrder.nativeOrder()); + intValues = new int[imgHeight * imgWidth]; + } catch (Exception e) { + Log.e(TAG, e.getMessage()); + Log.e(TAG, "Failed to initialize ImageNet classifier for the benchmarker."); + } + } + + /** Return how many classes are predicted per image. */ + public int getNumPredictions() { + return classifier.getNumPredictions(); + } + + /** + * Perform test on a single bitmap image. + * + * @param bitmap: a {@link Bitmap} image to classify. 
+ */ + public OvicSingleImageResult doTestIteration(Bitmap bitmap) + throws IOException, InterruptedException { + if (shouldStop() || !readyToTest()) { + return null; + } + OvicSingleImageResult iterResult = null; + try { + Log.i(TAG, "Converting bitmap."); + convertBitmapToInput(bitmap); + Log.i(TAG, "Classifying image."); + iterResult = classifier.classifyByteBuffer(imgData); + } catch (RuntimeException e) { + Log.e(TAG, e.getMessage()); + Log.e(TAG, "Failed to classify image."); + } + if (iterResult == null || iterResult.latency == null) { + throw new RuntimeException("Classification result or timing is invalid."); + } + Log.d(TAG, "Native inference latency: " + iterResult.latency); + Log.i(TAG, iterResult.toString()); + + if (!benchmarkStarted) { // Skip the first image to discount warming-up time. + benchmarkStarted = true; + } else { + totalRuntime += (double) iterResult.latency; + } + return iterResult; + } + + /** + * Writes Image data into a {@link ByteBuffer}. + * + * @param bitmap: a {@link Bitmap} source image. + */ + private void convertBitmapToInput(Bitmap bitmap) throws RuntimeException { + if (imgData == null) { + throw new RuntimeException("Benchmarker is not yet ready to test."); + } + imgData.rewind(); + // Perform transformations corresponding to evaluation mode. + float width = (float) bitmap.getWidth(); + float height = (float) bitmap.getHeight(); + int stWidth = Math.round((width - width * CENTRAL_FRACTION) / 2); + int stHeight = Math.round((height - height * CENTRAL_FRACTION) / 2); + int newWidth = Math.round(width - stWidth * 2); + int newHeight = Math.round(height - stHeight * 2); + bitmap = Bitmap.createBitmap(bitmap, stWidth, stHeight, newWidth, newHeight); + bitmap = Bitmap.createScaledBitmap(bitmap, imgWidth, imgHeight, true); + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); + + // Convert the image to ByteBuffer. 
+ int pixel = 0; + long startTime = SystemClock.uptimeMillis(); + + for (int i = 0; i < imgHeight; ++i) { + for (int j = 0; j < imgWidth; ++j) { + final int val = intValues[pixel++]; + imgData.put((byte) ((val >> 16) & 0xFF)); + imgData.put((byte) ((val >> 8) & 0xFF)); + imgData.put((byte) (val & 0xFF)); + } + } + long endTime = SystemClock.uptimeMillis(); + Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime)); + } +} diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java new file mode 100644 index 0000000000..b2dfd8f2e7 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java @@ -0,0 +1,209 @@ +/*Copyright 2018 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +package org.tensorflow.ovic; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import org.tensorflow.lite.Interpreter; +import org.tensorflow.lite.TestHelper; + +/** Benchmark ImageNet Classifier with Tensorflow Lite. */ +public class OvicClassifier { + + /** Tag for the {@link Log}. */ + private static final String TAG = "OvicClassifier"; + + /** Number of results to show (i.e. the "K" in top-K predictions). */ + private static final int RESULTS_TO_SHOW = 5; + + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + private Interpreter tflite; + + /** Labels corresponding to the output of the vision model. */ + private List<String> labelList; + + /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */ + private byte[][] inferenceOutputArray = null; + /** An array to hold final prediction probabilities. */ + private float[][] labelProbArray = null; + + /** Input resultion. */ + private int[] inputDims = null; + /** Whether the model runs as float or quantized. */ + private Boolean outputIsFloat = null; + + private PriorityQueue<Map.Entry<Integer, Float>> sortedLabels = + new PriorityQueue<>( + RESULTS_TO_SHOW, + new Comparator<Map.Entry<Integer, Float>>() { + @Override + public int compare(Map.Entry<Integer, Float> o1, Map.Entry<Integer, Float> o2) { + return (o1.getValue()).compareTo(o2.getValue()); + } + }); + + /** Initializes an {@code OvicClassifier}. 
*/ + OvicClassifier(InputStream labelInputStream, MappedByteBuffer model) + throws IOException, RuntimeException { + if (model == null) { + throw new RuntimeException("Input model is empty."); + } + labelList = loadLabelList(labelInputStream); + // OVIC uses one thread for CPU inference. + tflite = new Interpreter(model, 1); + inputDims = TestHelper.getInputDims(tflite, 0); + if (inputDims.length != 4) { + throw new RuntimeException("The model's input dimensions must be 4 (BWHC)."); + } + if (inputDims[0] != 1) { + throw new RuntimeException("The model must have a batch size of 1, got " + + inputDims[0] + " instead."); + } + if (inputDims[3] != 3) { + throw new RuntimeException("The model must have three color channels, got " + + inputDims[3] + " instead."); + } + int minSide = Math.min(inputDims[1], inputDims[2]); + int maxSide = Math.max(inputDims[1], inputDims[2]); + if (minSide <= 0 || maxSide > 1000) { + throw new RuntimeException("The model's resolution must be between (0, 1000]."); + } + String outputDataType = TestHelper.getOutputDataType(tflite, 0); + if (outputDataType.equals("float")) { + outputIsFloat = true; + } else if (outputDataType.equals("byte")) { + outputIsFloat = false; + } else { + throw new RuntimeException("Cannot process output type: " + outputDataType); + } + inferenceOutputArray = new byte[1][labelList.size()]; + labelProbArray = new float[1][labelList.size()]; + } + + /** Classifies a {@link ByteBuffer} image. */ + // @throws RuntimeException if model is uninitialized. 
+ OvicSingleImageResult classifyByteBuffer(ByteBuffer imgData) throws RuntimeException { + if (tflite == null) { + throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed."); + } + if (outputIsFloat == null) { + throw new RuntimeException(TAG + ": Classifier output type has not been resolved."); + } + if (outputIsFloat) { + tflite.run(imgData, labelProbArray); + } else { + tflite.run(imgData, inferenceOutputArray); + /** Convert results to float */ + for (int i = 0; i < inferenceOutputArray[0].length; i++) { + labelProbArray[0][i] = (inferenceOutputArray[0][i] & 0xff) / 255.0f; + } + } + OvicSingleImageResult iterResult = computeTopKLabels(); + iterResult.latency = getLastNativeInferenceLatencyMilliseconds(); + return iterResult; + } + + /** Return the probability array of all classes. */ + public float[][] getlabelProbArray() { + return labelProbArray; + } + + /** Return the number of top labels predicted by the classifier. */ + public int getNumPredictions() { + return RESULTS_TO_SHOW; + } + + /** Return the four dimensions of the input image. */ + public int[] getInputDims() { + return inputDims; + } + + /* + * Get native inference latency of last image classification run. + * @throws RuntimeException if model is uninitialized. + */ + public Long getLastNativeInferenceLatencyMilliseconds() { + if (tflite == null) { + throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed."); + } + Long latency = tflite.getLastNativeInferenceDurationNanoseconds(); + return (latency == null) ? null : (Long) (latency / 1000000); + } + + /** Closes tflite to release resources. */ + public void close() { + tflite.close(); + tflite = null; + } + + /** Reads label list from Assets. 
*/ + private static List<String> loadLabelList(InputStream labelInputStream) throws IOException { + List<String> labelList = new ArrayList<String>(); + try (BufferedReader reader = + new BufferedReader(new InputStreamReader(labelInputStream, StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + labelList.add(line); + } + } + return labelList; + } + + /** Computes top-K labels. */ + private OvicSingleImageResult computeTopKLabels() { + if (labelList == null) { + throw new RuntimeException("Label file has not been loaded."); + } + for (int i = 0; i < labelList.size(); ++i) { + sortedLabels.add(new AbstractMap.SimpleEntry<>(i, labelProbArray[0][i])); + if (sortedLabels.size() > RESULTS_TO_SHOW) { + sortedLabels.poll(); + } + } + OvicSingleImageResult singleImageResult = new OvicSingleImageResult(); + if (sortedLabels.size() != RESULTS_TO_SHOW) { + throw new RuntimeException( + "Number of returned labels does not match requirement: " + + sortedLabels.size() + + " returned, but " + + RESULTS_TO_SHOW + + " required."); + } + for (int i = 0; i < RESULTS_TO_SHOW; ++i) { + Map.Entry<Integer, Float> label = sortedLabels.poll(); + // ImageNet model prediction indices are 0-based. + singleImageResult.topKIndices.add(label.getKey()); + singleImageResult.topKClasses.add(labelList.get(label.getKey())); + singleImageResult.topKProbs.add(label.getValue()); + } + // Labels with lowest probability are returned first, hence need to reverse them. 
+ Collections.reverse(singleImageResult.topKIndices); + Collections.reverse(singleImageResult.topKClasses); + Collections.reverse(singleImageResult.topKProbs); + return singleImageResult; + } +} diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java new file mode 100644 index 0000000000..4af9a65c2f --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java @@ -0,0 +1,54 @@ +/*Copyright 2018 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.ovic; + +import java.util.ArrayList; + +/** Result class for inference run on a single image. */ +public class OvicSingleImageResult { + + /** Top K classes and probabilities. */ + public ArrayList<String> topKClasses; + public ArrayList<Float> topKProbs; + public ArrayList<Integer> topKIndices; + + /** Latency (ms). 
*/ + public Long latency; + + OvicSingleImageResult() { + topKClasses = new ArrayList<>(); + topKProbs = new ArrayList<>(); + topKIndices = new ArrayList<>(); + latency = -1L; + } + + @Override + public String toString() { + String textToShow = latency + "ms"; + for (int k = 0; k < topKProbs.size(); ++k) { + textToShow += + "\nPrediction [" + + k + + "] = Class " + + Integer.toString(topKIndices.get(k)) + + " (" + + topKClasses.get(k) + + ") : " + + Float.toString(topKProbs.get(k)); + } + return textToShow; + } + +} diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java new file mode 100644 index 0000000000..4fd23a99d2 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java @@ -0,0 +1,176 @@ +/*Copyright 2018 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +package org.tensorflow.ovic; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import java.awt.image.BufferedImage; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Paths; +import javax.imageio.ImageIO; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.ovic.OvicClassifier}. */ +@RunWith(JUnit4.class) +public final class OvicClassifierTest { + + private OvicClassifier classifier; + private InputStream labelsInputStream = null; + private MappedByteBuffer quantizedModel = null; + private MappedByteBuffer floatModel = null; + private MappedByteBuffer lowResModel = null; + private ByteBuffer testImage = null; + private ByteBuffer lowResTestImage = null; + private OvicSingleImageResult testResult = null; + private static final String LABELS_PATH = "testdata/labels.txt"; + private static final String QUANTIZED_MODEL_PATH = "testdata/quantized_model.lite"; + private static final String LOW_RES_MODEL_PATH = "testdata/low_res_model.lite"; + private static final String FLOAT_MODEL_PATH = "testdata/float_model.lite"; + private static final String TEST_IMAGE_PATH = "testdata/test_image_224.jpg"; + private static final String TEST_LOW_RES_IMAGE_PATH = "testdata/test_image_128.jpg"; + private static final int TEST_IMAGE_GROUNDTRUTH = 653; // "military uniform" + + @Before + public void setUp() { + try { + File labelsfile = new File(getTestDir(LABELS_PATH)); + labelsInputStream = new FileInputStream(labelsfile); + quantizedModel = loadModelFile(getTestDir(QUANTIZED_MODEL_PATH)); + floatModel = 
loadModelFile(getTestDir(FLOAT_MODEL_PATH)); + lowResModel = loadModelFile(getTestDir(LOW_RES_MODEL_PATH)); + File imageFile = new File(getTestDir(TEST_IMAGE_PATH)); + BufferedImage img = ImageIO.read(imageFile); + testImage = toByteBuffer(img); + // Low res image and models. + imageFile = new File(getTestDir(TEST_LOW_RES_IMAGE_PATH)); + img = ImageIO.read(imageFile); + lowResTestImage = toByteBuffer(img); + } catch (IOException e) { + System.out.print(e.getMessage()); + } + System.out.println("Successful setup"); + } + + private static String getTestDir(String testfile) throws IOException { + return Paths.get("third_party/tensorflow/contrib/lite/java/ovic/src/", testfile).toString(); + } + + @Test + public void ovicClassifier_quantizedModelCreateSuccess() throws Exception { + classifier = new OvicClassifier(labelsInputStream, quantizedModel); + assertThat(classifier != null).isTrue(); + } + + @Test + public void ovicClassifier_floatModelCreateSuccess() throws Exception { + classifier = new OvicClassifier(labelsInputStream, floatModel); + assertThat(classifier != null).isTrue(); + } + + @Test + public void ovicClassifier_quantizedModelClassifySuccess() throws Exception { + classifier = new OvicClassifier(labelsInputStream, quantizedModel); + testResult = classifier.classifyByteBuffer(testImage); + assertCorrectTopK(testResult); + } + + @Test + public void ovicClassifier_floatModelClassifySuccess() throws Exception { + classifier = new OvicClassifier(labelsInputStream, floatModel); + testResult = classifier.classifyByteBuffer(testImage); + assertCorrectTopK(testResult); + } + + @Test + public void ovicClassifier_lowResModelClassifySuccess() throws Exception { + classifier = new OvicClassifier(labelsInputStream, lowResModel); + testResult = classifier.classifyByteBuffer(lowResTestImage); + assertCorrectTopK(testResult); + } + + @Test + public void ovicClassifier_latencyNotNull() throws Exception { + classifier = new OvicClassifier(labelsInputStream, floatModel); + 
testResult = classifier.classifyByteBuffer(testImage); + assertThat(testResult.latency != null).isTrue(); + } + + @Test + public void ovicClassifier_mismatchedInputResolutionFails() throws Exception { + classifier = new OvicClassifier(labelsInputStream, lowResModel); + int[] inputDims = classifier.getInputDims(); + assertThat((inputDims[1] == 128) && (inputDims[2] == 128)).isTrue(); + try { + testResult = classifier.classifyByteBuffer(testImage); + fail(); + } catch (RuntimeException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Failed to get input dimensions. 0-th input should have 49152 bytes, " + + "but found 150528 bytes."); + } + } + + private static ByteBuffer toByteBuffer(BufferedImage image) { + ByteBuffer imgData = ByteBuffer.allocateDirect( + image.getHeight() * image.getWidth() * 3); + imgData.order(ByteOrder.nativeOrder()); + for (int y = 0; y < image.getHeight(); y++) { + for (int x = 0; x < image.getWidth(); x++) { + int val = image.getRGB(x, y); + imgData.put((byte) ((val >> 16) & 0xFF)); + imgData.put((byte) ((val >> 8) & 0xFF)); + imgData.put((byte) (val & 0xFF)); + } + } + return imgData; + } + + private static void assertCorrectTopK(OvicSingleImageResult testResult) { + assertThat(testResult.topKClasses.size() > 0).isTrue(); + Boolean topKAccurate = false; + // Assert that the correct class is in the top K. 
+ for (int i = 0; i < testResult.topKIndices.size(); i++) { + if (testResult.topKIndices.get(i) == TEST_IMAGE_GROUNDTRUTH) { + topKAccurate = true; + break; + } + } + System.out.println(testResult.toString()); + System.out.flush(); + assertThat(topKAccurate).isTrue(); + } + + private static MappedByteBuffer loadModelFile(String modelFilePath) throws IOException { + File modelfile = new File(modelFilePath); + FileInputStream inputStream = new FileInputStream(modelfile); + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = 0L; + long declaredLength = fileChannel.size(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } +} diff --git a/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt b/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt new file mode 100644 index 0000000000..fe811239d8 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede 
+black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish 
water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish 
+puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown 
+grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie 
+rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig 
+pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index deab5a91d2..68bce19aa3 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -754,7 +754,7 @@ def make_mean_tests(zip_path): [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] ], "const_axis": [True, False], - "keep_dims": [True, False], + "keepdims": [True, False], }, { "input_dtype": [tf.float32, tf.int32, tf.int64], "input_shape": [[1, 224, 224, 3]], @@ -765,7 +765,7 @@ def make_mean_tests(zip_path): [2, 2, 3], [-3, -3, -4], [-3, 2, 1] ], "const_axis": [True, False], - "keep_dims": [True, False], + "keepdims": [True, False], }] def build_graph(parameters): @@ -788,7 +788,7 @@ def make_mean_tests(zip_path): input_tensors = [input_tensor, axis] out = tf.reduce_mean( - input_tensor, axis=axis, keep_dims=parameters["keep_dims"]) + input_tensor, axis=axis, keepdims=parameters["keepdims"]) return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 486ff1edcd..102740ee47 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -124,6 +124,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -168,6 +169,41 @@ cc_library( ) cc_library( + name = 
"toco_saved_model", + srcs = [ + "toco_saved_model.cc", + ], + hdrs = [ + "toco_saved_model.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model_cmdline_flags", + ":model_flags_proto_cc", + ":toco_flags_proto_cc", + ":types_proto_cc", + "//tensorflow/cc/tools:freeze_saved_model", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "toco_saved_model_test", + srcs = ["toco_saved_model_test.cc"], + deps = [ + ":model_cmdline_flags", + ":toco_cmdline_flags", + ":toco_saved_model", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( name = "graph_transformations", srcs = [ "graph_transformations/convert_expanddims_to_reshape.cc", @@ -363,6 +399,7 @@ tf_cc_binary( ":toco_cmdline_flags", ":toco_flags_proto_cc", ":toco_port", + ":toco_saved_model", ":toco_tooling", ":types_proto_cc", "//tensorflow/core:lib", diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 59a6115920..7b71792ff7 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -190,6 +190,7 @@ struct ParsedModelFlags { Arg<string> output_array; Arg<string> output_arrays; Arg<string> input_shapes; + Arg<int> batch_size = Arg<int>(1); Arg<float> mean_value = Arg<float>(0.f); Arg<string> mean_values; Arg<float> std_value = Arg<float>(1.f); @@ -215,9 +216,11 @@ struct ParsedModelFlags { // you want). See toco_cmdline_flags.cc for details. 
struct ParsedTocoFlags { Arg<string> input_file; + Arg<string> savedmodel_directory; Arg<string> output_file; - Arg<string> input_format; - Arg<string> output_format; + Arg<string> input_format = Arg<string>("TENSORFLOW_GRAPHDEF"); + Arg<string> output_format = Arg<string>("TFLITE"); + Arg<string> savedmodel_tagset; // TODO(aselle): command_line_flags doesn't support doubles Arg<float> default_ranges_min = Arg<float>(0.); Arg<float> default_ranges_max = Arg<float>(0.); diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index a7a50e6fc9..b844e0b948 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1541,7 +1541,9 @@ void ConvertMeanOperator(const NodeDef& node, op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { + if (HasAttr(node, "keepdims")) { + op->keep_dims = GetBoolAttr(node, "keepdims"); + } else if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } } diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 4e2dec15a5..4264f21c76 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -72,6 +72,12 @@ bool ParseModelFlagsFromCommandLineFlags( "Shapes corresponding to --input_arrays, colon-separated. For " "many models each shape takes the form batch size, input array " "height, input array width, input array depth."), + Flag("batch_size", parsed_flags.batch_size.bind(), + parsed_flags.batch_size.default_value(), + "Batch size for the model. Replaces the first dimension of an " + "input size array if undefined. Use only with SavedModels when " + "--input_shapes flag is not specified. 
Always use --input_shapes " + "flag with frozen graphs."), Flag("input_data_type", parsed_flags.input_data_type.bind(), parsed_flags.input_data_type.default_value(), "Deprecated: use --input_data_types instead. Input array type, if " diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc index f01ec0ec61..8041aa9e7f 100644 --- a/tensorflow/contrib/lite/toco/toco.cc +++ b/tensorflow/contrib/lite/toco/toco.cc @@ -23,40 +23,70 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_saved_model.h" #include "tensorflow/contrib/lite/toco/toco_tooling.h" #include "tensorflow/contrib/lite/toco/toco_types.h" #include "tensorflow/core/platform/logging.h" -#ifndef CHECK_OK -#define CHECK_OK(val) CHECK_EQ((val).ok(), true) -#define QCHECK_OK(val) QCHECK_EQ((val).ok(), true) -#endif - namespace toco { namespace { -#define QCHECK_REQUIRE_TOCO_FLAG(arg) \ - QCHECK(parsed_toco_flags.arg.specified()) << "Missing required flag: " #arg; - -void CheckFilePermissions(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - const TocoFlags& toco_flags) { - port::CheckInitGoogleIsDone("InitGoogle is not done yet"); - - QCHECK_REQUIRE_TOCO_FLAG(input_file) - QCHECK_OK(port::file::Exists(parsed_toco_flags.input_file.value(), - port::file::Defaults())) - << "Specified input_file does not exist: " - << parsed_toco_flags.input_file.value(); - QCHECK_OK(port::file::Readable(parsed_toco_flags.input_file.value(), - port::file::Defaults())) +// Checks the permissions of the output file to ensure it is writeable. 
+void CheckOutputFilePermissions(const Arg<string>& output_file) { + QCHECK(output_file.specified()) << "Missing required flag --output_file.\n"; + QCHECK(port::file::Writable(output_file.value()).ok()) + << "Specified output_file is not writable: " << output_file.value() + << ".\n"; +} + +// Checks the permissions of the frozen model file. +void CheckFrozenModelPermissions(const Arg<string>& input_file) { + QCHECK(input_file.specified()) << "Missing required flag --input_file.\n"; + QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok()) + << "Specified input_file does not exist: " << input_file.value() << ".\n"; + QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok()) << "Specified input_file exists, but is not readable: " - << parsed_toco_flags.input_file.value(); + << input_file.value() << ".\n"; +} - QCHECK_REQUIRE_TOCO_FLAG(output_file); - QCHECK_OK(port::file::Writable(parsed_toco_flags.output_file.value())) - << "parsed_toco_flags.input_file.value() output_file is not writable: " - << parsed_toco_flags.output_file.value(); +// Checks the permissions of the SavedModel directory. +void CheckSavedModelPermissions(const Arg<string>& savedmodel_directory) { + QCHECK(savedmodel_directory.specified()) + << "Missing required flag --savedmodel_directory.\n"; + QCHECK( + port::file::Exists(savedmodel_directory.value(), port::file::Defaults()) + .ok()) + << "Specified savedmodel_directory does not exist: " + << savedmodel_directory.value() << ".\n"; +} + +// Reads the contents of the GraphDef from either the frozen graph file or the +// SavedModel directory. If it reads the SavedModel directory, it updates the +// ModelFlags and TocoFlags accordingly. 
+void ReadInputData(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags, + string* graph_def_contents) { + port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n"); + + bool has_input_file = parsed_toco_flags.input_file.specified(); + bool has_savedmodel_dir = parsed_toco_flags.savedmodel_directory.specified(); + + // Ensure either input_file or savedmodel_directory flag has been set. + QCHECK_NE(has_input_file, has_savedmodel_dir) + << "Specify either input_file or savedmodel_directory flag.\n"; + + // Checks the input file permissions and reads the contents. + if (has_input_file) { + CheckFrozenModelPermissions(parsed_toco_flags.input_file); + CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), + graph_def_contents, port::file::Defaults()) + .ok()); + } else { + CheckSavedModelPermissions(parsed_toco_flags.savedmodel_directory); + GetSavedModelContents(parsed_toco_flags, parsed_model_flags, toco_flags, + model_flags, graph_def_contents); + } } void ToolMain(const ParsedTocoFlags& parsed_toco_flags, @@ -67,21 +97,20 @@ void ToolMain(const ParsedTocoFlags& parsed_toco_flags, TocoFlags toco_flags; ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags); - CheckFilePermissions(parsed_toco_flags, parsed_model_flags, toco_flags); + string graph_def_contents; + ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags, + &model_flags, &graph_def_contents); + CheckOutputFilePermissions(parsed_toco_flags.output_file); - string input_file_contents; - CHECK_OK(port::file::GetContents(parsed_toco_flags.input_file.value(), - &input_file_contents, - port::file::Defaults())); std::unique_ptr<Model> model = - Import(toco_flags, model_flags, input_file_contents); + Import(toco_flags, model_flags, graph_def_contents); Transform(toco_flags, model.get()); string output_file_contents; Export(toco_flags, *model, toco_flags.allow_custom_ops(), 
&output_file_contents); - CHECK_OK(port::file::SetContents(parsed_toco_flags.output_file.value(), - output_file_contents, - port::file::Defaults())); + CHECK(port::file::SetContents(parsed_toco_flags.output_file.value(), + output_file_contents, port::file::Defaults()) + .ok()); } } // namespace diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index 0f67c2de72..cc7803dd86 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/strip.h" +#include "absl/types/optional.h" #include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" #include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/core/platform/logging.h" @@ -38,6 +39,9 @@ bool ParseTocoFlagsFromCommandLineFlags( "Input file (model of any supported format). For Protobuf " "formats, both text and binary are supported regardless of file " "extension."), + Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(), + parsed_flags.savedmodel_directory.default_value(), + "Full path to the directory containing the SavedModel."), Flag("output_file", parsed_flags.output_file.bind(), parsed_flags.output_file.default_value(), "Output file. " @@ -49,6 +53,11 @@ bool ParseTocoFlagsFromCommandLineFlags( parsed_flags.output_format.default_value(), "Output file format. " "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."), + Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(), + parsed_flags.savedmodel_tagset.default_value(), + "Comma-separated set of tags identifying the MetaGraphDef within " + "the SavedModel to analyze. 
All tags in the tag set must be " + "specified."), Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(), parsed_flags.default_ranges_min.default_value(), "If defined, will be used as the default value for the min bound " @@ -128,47 +137,72 @@ bool ParseTocoFlagsFromCommandLineFlags( } } +namespace { + +// Defines the requirements for a given flag. kUseDefault means the default +// should be used in cases where the value isn't specified by the user. +enum class FlagRequirement { + kNone, + kMustBeSpecified, + kMustNotBeSpecified, + kUseDefault, +}; + +// Enforces the FlagRequirements are met for a given flag. +template <typename T> +void EnforceFlagRequirement(const T& flag, const string& flag_name, + FlagRequirement requirement) { + if (requirement == FlagRequirement::kMustBeSpecified) { + QCHECK(flag.specified()) << "Missing required flag " << flag_name; + } + if (requirement == FlagRequirement::kMustNotBeSpecified) { + QCHECK(!flag.specified()) + << "Given other flags, this flag should not have been specified: " + << flag_name; + } +} + +// Gets the value from the flag if specified. Returns default if the +// FlagRequirement is kUseDefault. 
+template <typename T> +absl::optional<T> GetFlagValue(const Arg<T>& flag, + FlagRequirement requirement) { + if (flag.specified()) return flag.value(); + if (requirement == FlagRequirement::kUseDefault) return flag.default_value(); + return absl::optional<T>(); +} + +} // namespace + void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, TocoFlags* toco_flags) { namespace port = toco::port; port::CheckInitGoogleIsDone("InitGoogle is not done yet"); - enum class FlagRequirement { kNone, kMustBeSpecified, kMustNotBeSpecified }; - -#define ENFORCE_FLAG_REQUIREMENT(name, requirement) \ - do { \ - if (requirement == FlagRequirement::kMustBeSpecified) { \ - QCHECK(parsed_toco_flags.name.specified()) \ - << "Missing required flag: " << #name; \ - } \ - if (requirement == FlagRequirement::kMustNotBeSpecified) { \ - QCHECK(!parsed_toco_flags.name.specified()) \ - << "Given other flags, this flag should not have been specified: " \ - << #name; \ - } \ - } while (false) -#define READ_TOCO_FLAG(name, requirement) \ - ENFORCE_FLAG_REQUIREMENT(name, requirement); \ - do { \ - if (parsed_toco_flags.name.specified()) { \ - toco_flags->set_##name(parsed_toco_flags.name.value()); \ - } \ +#define READ_TOCO_FLAG(name, requirement) \ + do { \ + EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \ + auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \ + if (flag_value.has_value()) { \ + toco_flags->set_##name(flag_value.value()); \ + } \ } while (false) -#define PARSE_TOCO_FLAG(Type, name, requirement) \ - ENFORCE_FLAG_REQUIREMENT(name, requirement); \ - do { \ - if (parsed_toco_flags.name.specified()) { \ - Type x; \ - QCHECK(Type##_Parse(parsed_toco_flags.name.value(), &x)) \ - << "Unrecognized " << #Type << " value " \ - << parsed_toco_flags.name.value(); \ - toco_flags->set_##name(x); \ - } \ +#define PARSE_TOCO_FLAG(Type, name, requirement) \ + do { \ + EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); 
\ + auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \ + if (flag_value.has_value()) { \ + Type x; \ + QCHECK(Type##_Parse(flag_value.value(), &x)) \ + << "Unrecognized " << #Type << " value " \ + << parsed_toco_flags.name.value(); \ + toco_flags->set_##name(x); \ + } \ } while (false) - PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified); - PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kMustBeSpecified); + PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kUseDefault); + PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kUseDefault); PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone); PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone); READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone); diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.cc b/tensorflow/contrib/lite/toco/toco_saved_model.cc new file mode 100644 index 0000000000..91a742b9e0 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_saved_model.cc @@ -0,0 +1,186 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <string> +#include <vector> + +#include "absl/strings/numbers.h" +#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" +#include "tensorflow/contrib/lite/toco/toco_saved_model.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace toco { +namespace { + +// Loads a SavedModel from the directory specified in parsed_toco_flags. +// Returns a SavedModelBundle with the requested MetaGraphDef. +const tensorflow::SavedModelBundle* LoadSavedModel( + const ParsedTocoFlags& parsed_toco_flags) { + const string model_path = parsed_toco_flags.savedmodel_directory.value(); + QCHECK(tensorflow::MaybeSavedModelDirectory(model_path)) + << "Model is not saved in the supported SavedModel format.\n"; + + // Gets the tags identifying the MetaGraphDef from the command line arguments. + QCHECK(parsed_toco_flags.savedmodel_tagset.specified()) + << "Missing required flag --savedmodel_tagset.\n"; + const string tags_str = parsed_toco_flags.savedmodel_tagset.value(); + auto tags = absl::StrSplit(tags_str, ','); + + // Loads MetaGraphDef. + auto* bundle = new tensorflow::SavedModelBundle; + TF_CHECK_OK(tensorflow::LoadSavedModel(tensorflow::SessionOptions(), + tensorflow::RunOptions(), model_path, + tags, bundle)) + << "Failed to load exported model from " << model_path + << ". Ensure the model contains the required tags '" << tags_str + << "'.\n"; + return bundle; +} + +// Returns the array name without the postfix. +// +// e.g. reduces "input:0" to "input". +string GetArrayName(const string& name) { + const std::vector<string>& names = absl::StrSplit(name, ':'); + return names[0]; +} + +// Returns the list of array names without the postfix sorted alphabetically. 
+std::set<string> GetSortedNames(const std::unordered_set<string>& names) { + std::vector<string> final_names; + final_names.reserve(names.size()); + for (const auto& name : names) { + final_names.push_back(GetArrayName(name)); + } + return std::set<string>(final_names.begin(), final_names.end()); +} + +// Gets the final shape after replacing the first dimension with batch size, if +// it is undefined (containing the value -1). Returns whether the shape is +// valid. +bool ReplaceShapeBatchSize(const tensorflow::TensorShapeProto& shape, + int batch_size, + tensorflow::TensorShapeProto* final_shape) { + for (int idx = 0; idx < shape.dim().size(); ++idx) { + int64 final_dim = shape.dim()[idx].size(); + if (final_dim == -1) { + if (idx > 0) return false; + final_dim = batch_size; + } + final_shape->add_dim()->set_size(final_dim); + } + return true; +} + +// Updates the input arrays in ModelFlags to contain the shape of the array. +void ProcessInputShapes(const tensorflow::GraphDef& graph_def, int batch_size, + ModelFlags* model_flags) { + // Build map of input array names to input arrays. + std::unordered_map<string, InputArray*> input_data_map; + for (auto& input : *model_flags->mutable_input_arrays()) { + input_data_map[input.name()] = &input; + } + + // Adds shapes to the input arrays if the shape is valid. + for (const tensorflow::NodeDef& node_def : graph_def.node()) { + if (input_data_map.find(node_def.name()) != input_data_map.end()) { + const auto shape_it = node_def.attr().find("shape"); + if (shape_it != node_def.attr().end()) { + tensorflow::TensorShapeProto final_shape; + bool is_valid = ReplaceShapeBatchSize(shape_it->second.shape(), + batch_size, &final_shape); + + if (is_valid) { + auto* shape = input_data_map.at(node_def.name())->mutable_shape(); + QCHECK_EQ(shape->dims_size(), 0) + << "The shape for the input '" << node_def.name() + << "' was previously defined. 
For clarity please define inputs " + << "via --input_arrays and input_shapes flags.\n"; + for (const auto& dim : final_shape.dim()) { + shape->add_dims(dim.size()); + } + } + } + } + } + + // Checks all input arrays have a shape. + for (auto const& input : model_flags->input_arrays()) { + QCHECK(input.shape().dims_size() > 0) + << "A valid input shape was not found for input '" << input.name() + << "'. Please define via --input_arrays and --input_shapes flags.\n"; + } +} + +} // namespace + +void ParseMetaData(const tensorflow::GraphDef& graph_def, + const std::unordered_set<string>& inputs, + const std::unordered_set<string>& outputs, + const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags) { + if (!parsed_model_flags.input_arrays.specified()) { + const std::set<string> sorted_inputs = GetSortedNames(inputs); + for (const auto& input_name : sorted_inputs) { + model_flags->add_input_arrays()->set_name(input_name); + } + } + + if (!parsed_model_flags.output_arrays.specified()) { + const std::set<string> sorted_outputs = GetSortedNames(outputs); + for (const auto& output_name : sorted_outputs) { + model_flags->add_output_arrays(GetArrayName(output_name)); + } + } + + if (!parsed_model_flags.input_shapes.specified()) { + int batch_size = parsed_model_flags.batch_size.value(); + ProcessInputShapes(graph_def, batch_size, model_flags); + } + + if (!parsed_toco_flags.inference_type.specified()) { + toco_flags->set_inference_type(IODataType::FLOAT); + } +} + +// TODO(nupurgarg): Add top level tests. +void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags, + string* graph_def_contents) { + // Loads the MetaGraphDef within a SavedModelBundle. + auto bundle = LoadSavedModel(parsed_toco_flags); + + // Converts the MetaGraphDef to frozen GraphDef. 
+ tensorflow::GraphDef frozen_graph_def; + std::unordered_set<string> inputs; + std::unordered_set<string> outputs; + TF_CHECK_OK(tensorflow::FreezeSavedModel(*bundle, &frozen_graph_def, &inputs, + &outputs)); + + // Reads the frozen GraphDef into a string. + QCHECK(frozen_graph_def.SerializeToString(graph_def_contents)) + << "Unable to generate serialized GraphDef.\n"; + + // Process inputs and outputs and metadata within GraphDef. + const tensorflow::GraphDef graph_def = bundle->meta_graph_def.graph_def(); + ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags, + parsed_model_flags, toco_flags, model_flags); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.h b/tensorflow/contrib/lite/toco/toco_saved_model.h new file mode 100644 index 0000000000..7a0fabd82d --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_saved_model.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ + +#include <string> +#include <vector> + +#include "tensorflow/cc/tools/freeze_saved_model.h" +#include "tensorflow/contrib/lite/toco/args.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/types.pb.h" + +namespace toco { + +// Parses metadata into `toco_flags` and `model_flags`. +// +// Stores `inputs` as input_arrays and `outputs` as output_arrays in +// `model_flags`. Infers input_shapes from the GraphDef and stores it in +// `model_flags` as part of the input_arrays. Assumes inference_type is FLOAT +// and stores it in `toco_flags`. +void ParseMetaData(const tensorflow::GraphDef& graph_def, + const std::unordered_set<string>& inputs, + const std::unordered_set<string>& outputs, + const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags); + +// Generates a frozen graph from the SavedModel in the directory specified in +// `toco_flags`. Reads frozen graph contents into `graph_def_contents`. Parses +// metadata relating to the GraphDef into `toco_flags` and `model_flags`. +void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + TocoFlags* toco_flags, ModelFlags* model_flags, + string* graph_def_contents); + +} // namespace toco + +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ diff --git a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc new file mode 100644 index 0000000000..5e122afe65 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc @@ -0,0 +1,274 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/toco/toco_saved_model.h" +#include "absl/strings/str_join.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" +#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace toco { +namespace { + +using tensorflow::ops::Add; +using tensorflow::ops::Const; +using tensorflow::ops::FakeQuantWithMinMaxArgs; +using tensorflow::ops::Placeholder; + +class TocoSavedModelTest : public ::testing::Test { + protected: + // Calls functions to process cmdline arguments and calls ParseMetaData. + // ParseMetaData parses input_arrays, output_arrays, and gets metadata from + // the SavedModel if it is not defined in the cmdline arguments. 
+ void ProcessGraphDefMetadata(const std::unordered_set<string>& inputs, + const std::unordered_set<string>& outputs, + const tensorflow::GraphDef& graph_def) { + ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags_, &toco_flags_); + ReadModelFlagsFromCommandLineFlags(parsed_model_flags_, &model_flags_); + ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags_, + parsed_model_flags_, &toco_flags_, &model_flags_); + } + + // Gets the GraphDef from the SavedModelBundle and processes metadata. + void ProcessSavedModelMetadata(const std::unordered_set<string>& inputs, + const std::unordered_set<string>& outputs) { + const tensorflow::GraphDef graph_def = bundle_.meta_graph_def.graph_def(); + ProcessGraphDefMetadata(inputs, outputs, graph_def); + } + + // Returns a GraphDef representing a simple float model with a single input. + tensorflow::GraphDef GetFloatGraphDef(const std::vector<int64>& shape) { + tensorflow::GraphDef graph_def; + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + tensorflow::Output input = + Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT, + Placeholder::Shape(tensorflow::PartialTensorShape(shape))); + tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {}); + tensorflow::Output add = Add(scope.WithOpName("add"), input, zero); + + TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); + return graph_def; + } + + // Returns a GraphDef representing a simple float model with two inputs. 
+ tensorflow::GraphDef GetComplexFloatGraphDef() { + tensorflow::GraphDef graph_def; + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + tensorflow::Output inputA = + Placeholder(scope.WithOpName("inputA"), tensorflow::DT_FLOAT, + Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); + tensorflow::Output inputB = + Placeholder(scope.WithOpName("inputB"), tensorflow::DT_FLOAT, + Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); + tensorflow::Output add = Add(scope.WithOpName("add"), inputB, inputA); + + TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); + return graph_def; + } + + // Returns a GraphDef representing a simple quantized model. + tensorflow::GraphDef GetQuantizedGraphDef() { + tensorflow::GraphDef graph_def; + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + tensorflow::Output input = + Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT, + Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); + tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {}); + tensorflow::Output fake_quant = + FakeQuantWithMinMaxArgs(scope.WithOpName("quant"), zero); + tensorflow::Output add = Add(scope.WithOpName("add"), input, fake_quant); + + TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); + return graph_def; + } + + // Gets the values in the input_arrays flag. + std::vector<string> GetInputArrays() { + std::vector<string> actual; + for (const auto& input : model_flags_.input_arrays()) { + actual.push_back(input.name()); + } + return actual; + } + + // Gets the values in the output_arrays flag. + std::vector<string> GetOutputArrays() { + std::vector<string> actual(model_flags_.output_arrays().begin(), + model_flags_.output_arrays().end()); + return actual; + } + + // Gets the shape of the given input array. 
+ string GetInputShape(const string& input_array) { + for (const auto& input : model_flags_.input_arrays()) { + if (input.name() == input_array) { + std::vector<string> dims; + for (int idx = 0; idx < input.shape().dims_size(); ++idx) { + dims.push_back(std::to_string(input.shape().dims(idx))); + } + return absl::StrJoin(dims, ","); + } + } + return ""; + } + + tensorflow::SavedModelBundle bundle_; + ParsedTocoFlags parsed_toco_flags_; + ParsedModelFlags parsed_model_flags_; + TocoFlags toco_flags_; + ModelFlags model_flags_; +}; + +// Tests if input_arrays, output_arrays, inference_type, and input_shapes are +// added to ModelFlags if they are not specified in cmdline arguments. +// Tests if the default batch size replaces a -1 in the first dimension. +TEST_F(TocoSavedModelTest, NoCmdLine) { + tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1}); + + ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"})); + EXPECT_EQ(GetInputShape("input"), "1,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if the order of input_arrays and output_arrays is deterministic when +// they are taken from the SavedModel. +TEST_F(TocoSavedModelTest, NoCmdLineMultipleArrays) { + tensorflow::GraphDef graph_def = GetComplexFloatGraphDef(); + + // Note: The model does not have two outputs. However, the function does not + // need an accurate output_array list. This is only meant to test order. 
+ ProcessGraphDefMetadata({"inputB", "inputA"}, {"add", "invalid"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add", "invalid"})); + EXPECT_EQ(GetInputShape("inputA"), "1,3,3,1"); + EXPECT_EQ(GetInputShape("inputB"), "1,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if input_shapes is inferred when input_arrays is passed in via cmdline +// arguments. +TEST_F(TocoSavedModelTest, InputNameWithoutInputShape) { + parsed_model_flags_.input_arrays.bind()("input"); + tensorflow::GraphDef graph_def = GetFloatGraphDef({2, 3, 3, 1}); + + ProcessGraphDefMetadata({"not_used_input"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"})); + EXPECT_EQ(GetInputShape("input"), "2,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Ensures a failure occurs when input_shapes is defined without input_arrays. +TEST_F(TocoSavedModelTest, InputShapeWithoutInputName) { + parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); + tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1}); + + EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def), + "failed: input_shapes.size\\(\\) == " + "model_flags->input_arrays_size\\(\\)"); +} + +// Tests if the cmdline values of input_arrays, input_shapes are used when +// specified with an empty GraphDef. 
+TEST_F(TocoSavedModelTest, InputArraysCmdLine) { + parsed_model_flags_.input_arrays.bind()("inputA,inputB"); + parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); + + ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"}); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"output0", "output1"})); + EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); + EXPECT_EQ(GetInputShape("inputB"), "9,12"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if the cmdline values of input_arrays, input_shapes are used when +// specified even if values exist within the GraphDef. +TEST_F(TocoSavedModelTest, InputArraysCmdLineWithGraphDef) { + parsed_model_flags_.input_arrays.bind()("inputA"); + parsed_model_flags_.input_shapes.bind()("1,224,224,1"); + tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1}); + + ProcessGraphDefMetadata({"inputA"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"})); + EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if the cmdline values of input_arrays, input_shapes, inference_type, +// and output_arrays are used when specified with an empty GraphDef. 
+TEST_F(TocoSavedModelTest, AllParamsCmdLine) { + parsed_model_flags_.input_arrays.bind()("inputA,inputB"); + parsed_model_flags_.output_arrays.bind()("outputA,outputB"); + parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); + parsed_toco_flags_.inference_type.bind()("FLOAT"); + + ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"}); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"outputA", "outputB"})); + EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); + EXPECT_EQ(GetInputShape("inputB"), "9,12"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Tests if a quantized graph gives the correct values assuming type is passed +// in via command line. +TEST_F(TocoSavedModelTest, QuantizedNoCmdLine) { + parsed_toco_flags_.inference_type.bind()("QUANTIZED_UINT8"); + tensorflow::GraphDef graph_def = GetQuantizedGraphDef(); + + ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"})); + EXPECT_EQ(GetInputShape("input"), "1,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::QUANTIZED_UINT8); +} + +// Tests if the provided batch size replaces a -1 in the first dimension of +// input shape. +TEST_F(TocoSavedModelTest, MissingShapeParameterValid) { + parsed_model_flags_.batch_size.bind()(3); + tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1}); + + ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); + EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"})); + EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"})); + EXPECT_EQ(GetInputShape("input"), "3,3,3,1"); + EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); +} + +// Ensures a failure occurs if there is a -1 in a dimension aside from the first +// position of input shape. 
+TEST_F(TocoSavedModelTest, MissingShapeParameterInvalid) { + parsed_model_flags_.batch_size.bind()(3); + tensorflow::GraphDef graph_def = GetFloatGraphDef({1, -1, 3, 1}); + + EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def), + "A valid input shape was not found for input 'input'."); +} + +} // namespace +} // namespace toco diff --git a/tensorflow/contrib/py2tf/converters/call_trees.py b/tensorflow/contrib/py2tf/converters/call_trees.py index e3040f09e4..f498b814bf 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees.py +++ b/tensorflow/contrib/py2tf/converters/call_trees.py @@ -118,6 +118,12 @@ class CallTreeTransformer(transformer.Base): def _should_compile(self, node, fqn): """Determines whether an entity should be compiled in the context.""" + # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether. + module_name = fqn[0] + for mod in self.uncompiled_modules: + if module_name.startswith(mod[0] + '.'): + return False + for i in range(1, len(fqn)): if fqn[:i] in self.uncompiled_modules: return False diff --git a/tensorflow/contrib/py2tf/converters/for_loops.py b/tensorflow/contrib/py2tf/converters/for_loops.py index 4297c1cf2a..8d28b149a8 100644 --- a/tensorflow/contrib/py2tf/converters/for_loops.py +++ b/tensorflow/contrib/py2tf/converters/for_loops.py @@ -38,19 +38,19 @@ class ForLoopCanonicalizationTransformer(transformer.Base): self.generic_visit(node) body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) i_var = self.context.namer.new_symbol('i', body_scope.referenced) - n_var = self.context.namer.new_symbol('n', body_scope.referenced) - iterated_var = self.context.namer.new_symbol('iterated', - body_scope.referenced) + smart_loop_iter_var = self.context.namer.new_symbol('smart_loop_iter', + body_scope.referenced) + cont_var = self.context.namer.new_symbol('cont', body_scope.referenced) # TODO(mdan): Use TensorListFromTensor(loop_iter) here. 
if anno.hasanno(node, 'extra_cond'): template = """ i = 0 - iterated = loop_iter - n = len(iterated) - while i < n and extra_cond: - target = iterated[i] + smart_loop_iter = py2tf_utils.dynamic_dataset(loop_iter) + cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter) + while cont and extra_cond: body i += 1 + cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter) """ return templates.replace( template, @@ -58,18 +58,18 @@ class ForLoopCanonicalizationTransformer(transformer.Base): target=node.target, body=node.body, i=i_var, - n=n_var, - iterated=iterated_var, + smart_loop_iter=smart_loop_iter_var, + cont=cont_var, extra_cond=anno.getanno(node, 'extra_cond')) else: template = """ i = 0 - iterated = loop_iter - n = len(iterated) - while i < n: - target = iterated[i] + smart_loop_iter = py2tf_utils.dynamic_dataset(loop_iter) + cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter) + while cont: body i += 1 + cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter) """ repl = templates.replace( template, @@ -77,8 +77,8 @@ class ForLoopCanonicalizationTransformer(transformer.Base): target=node.target, body=node.body, i=i_var, - n=n_var, - iterated=iterated_var) + smart_loop_iter=smart_loop_iter_var, + cont=cont_var) return repl def visit_Continue(self, node): diff --git a/tensorflow/contrib/py2tf/converters/lists.py b/tensorflow/contrib/py2tf/converters/lists.py index 12ebd00062..3e62037a50 100644 --- a/tensorflow/contrib/py2tf/converters/lists.py +++ b/tensorflow/contrib/py2tf/converters/lists.py @@ -67,6 +67,9 @@ class ListTransformer(transformer.Base): node = self.generic_visit(node) if isinstance(node.value, gast.Call): call_node = node.value + + if not anno.hasanno(call_node.func, anno.Basic.QN): + return node qn = anno.getanno(call_node.func, anno.Basic.QN) if qn.qn[-1] == 'append' and (len(call_node.args) == 1): diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py 
b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py index 5556a58c02..a969adbeca 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py @@ -168,6 +168,15 @@ class TypeInfoResolver(transformer.Base): anno.getanno(definition, 'element_type')) return node + def _process_tuple_assignment(self, source, t): + for i, e in enumerate(t.elts): + if isinstance(e, gast.Tuple): + self._process_tuple_assignment(source, e) + else: + self.scope.setval( + anno.getanno(e, anno.Basic.QN), + gast.Subscript(source, gast.Index(i), ctx=gast.Store())) + def _process_variable_assignment(self, source, targets): if isinstance(source, gast.Call): func = source.func @@ -183,10 +192,9 @@ class TypeInfoResolver(transformer.Base): for t in targets: if isinstance(t, gast.Tuple): - for i, e in enumerate(t.elts): - self.scope.setval( - anno.getanno(e, anno.Basic.QN), - gast.Subscript(source, gast.Index(i), ctx=gast.Store())) + # need to recurse on the case of assigning nested tuples, + # ex. 
a, (b, c) = f() + self._process_tuple_assignment(source, t) elif isinstance(t, (gast.Name, gast.Attribute)): self.scope.setval(anno.getanno(t, anno.Basic.QN), source) else: diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py index 0d9d5a85f0..8a8956197d 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py @@ -196,6 +196,23 @@ class TypeInfoResolverTest(test.TestCase): f_ref = node.body[0].body[1].value self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) + def test_nested_assignment(self): + + def test_fn(foo): + a, (b, c) = foo + return a, b, c + + node = self._parse_and_analyze(test_fn, {'foo': (1, 2, 3)}) + lhs = node.body[0].body[1].value.elts + a = lhs[0] + b = lhs[1] + c = lhs[2] + # TODO(mdan): change these once we have the live values propagating + # correctly + self.assertFalse(anno.hasanno(a, 'live_val')) + self.assertFalse(anno.hasanno(b, 'live_val')) + self.assertFalse(anno.hasanno(c, 'live_val')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/py2tf/utils/BUILD index d029289f5a..b53fbb5c18 100644 --- a/tensorflow/contrib/py2tf/utils/BUILD +++ b/tensorflow/contrib/py2tf/utils/BUILD @@ -35,6 +35,7 @@ py_library( deps = [ "//tensorflow/python:list_ops", "//tensorflow/python:script_ops", + "//tensorflow/python/data/ops:dataset_ops", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/py2tf/utils/__init__.py b/tensorflow/contrib/py2tf/utils/__init__.py index d9d8e34689..4e6003c852 100644 --- a/tensorflow/contrib/py2tf/utils/__init__.py +++ b/tensorflow/contrib/py2tf/utils/__init__.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.py2tf.utils.builtins import dynamic_builtin +from tensorflow.contrib.py2tf.utils.builtins import 
dynamic_dataset +from tensorflow.contrib.py2tf.utils.builtins import dynamic_for_cond from tensorflow.contrib.py2tf.utils.builtins import dynamic_print from tensorflow.contrib.py2tf.utils.builtins import dynamic_range from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns diff --git a/tensorflow/contrib/py2tf/utils/builtins.py b/tensorflow/contrib/py2tf/utils/builtins.py index 3cb62b55d4..251b4ed8ee 100644 --- a/tensorflow/contrib/py2tf/utils/builtins.py +++ b/tensorflow/contrib/py2tf/utils/builtins.py @@ -22,8 +22,10 @@ import six from tensorflow.contrib.py2tf.utils import py_func from tensorflow.contrib.py2tf.utils import type_check +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import tf_inspect @@ -54,7 +56,6 @@ def dynamic_len(list_or_tensor): raise ValueError( 'len requires non-zero rank for tensor "%s"' % list_or_tensor) return array_ops.shape(list_or_tensor)[0] - return len(list_or_tensor) @@ -97,3 +98,69 @@ def dynamic_print(*values): if all(map(is_tf_print_compatible, values)): return logging_ops.Print(1, values) return py_func.wrap_py_func(print, None, values, use_dummy_return=True) + + +def dynamic_dataset(iterated): + """Implementation of smart tf.data.Dataset epoch wrapping. + + The function checks if the input is a tf.data.Dataset and if so then wraps it + so that for each element it returns it also returns the current epoch the + dataset iteration is in, for two epochs. If the input is not a + tf.data.Dataset then it just returns the input. + + Args: + iterated: The iterable or tf.data.Dataset that is being iterated over. 
+ Returns: + Either just the untouched input, or in the case of input being a + tf.data.Dataset then it returns a wrapped tf.data.Dataset where for each + element it returns it also returns the current epoch the dataset iteration + is in. + """ + if not isinstance(iterated, dataset_ops.Dataset): + return iterated + + def epoch_dataset_number_helper(i): + return dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensors(i).repeat(), iterated)) + + epoch_numbers = dataset_ops.Dataset.range(2) + return epoch_numbers.flat_map(epoch_dataset_number_helper) + + +def dynamic_for_cond(iteration, iterated): + """Implementation of smart while-loop condition using dynamic dispatch. + + The function checks if it is iterating over a tf.data.Dataset or not, and in + the case it is not then it simply returns if we are still in range of the + iterated and the next element. If it is iterating over a dataset then it only + iterates for a single epoch. + + Args: + iteration: The current iteration of the loop. + iterated: The iterable or tf.data.Dataset that is being iterated over. + Returns: + A tuple of a bool that indicates whether the loop should continue, and the + next element in iterated. + """ + # TODO(znado): Clean up. + # TODO(znado): This won't work for unpacked iterates. Fix. 
+ if isinstance(iterated, dataset_ops.Dataset): + curr_epoch, next_elem = iterated.make_one_shot_iterator().get_next() + return math_ops.less(curr_epoch, 1), next_elem + elif tensor_util.is_tensor(iterated): + if iterated.shape.ndims > 1: + elem_shape = array_ops.shape(iterated)[1:] + else: + elem_shape = () + if iterated.shape.ndims == 0 or iterated.shape[0] == 0: + return False, array_ops.zeros(elem_shape, iterated.dtype) + return control_flow_ops.cond( + math_ops.less(iteration, dynamic_len(iterated)), + lambda: (True, iterated[iteration]), + lambda: (False, array_ops.zeros(elem_shape, iterated.dtype))) + elif hasattr(iterated, '__len__'): + if iteration < len(iterated): + return True, iterated[iteration] + return False, None + else: + raise NotImplementedError('Python iterators not yet supported.') diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index aaa6f3c2c1..152f8c8c69 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -931,8 +931,7 @@ class _InputPipeline(object): # In the model-parallel case, both the host-side and device-side # computations must agree on the core on which infeed takes place. We # choose to perform infeed on logical core 0 of each replica. - with ops.device(tpu.core(0)): - values = self._infeed_queue.generate_dequeue_op() + values = self._infeed_queue.generate_dequeue_op(tpu_device=0) # The unflatten process uses the structure information recorded above. 
return self._inputs_structure_recorder.unflatten_features_and_labels( values) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py index 42ac6eb680..604e6600c8 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -23,6 +23,7 @@ from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_sharding from tensorflow.python.framework import dtypes @@ -368,13 +369,20 @@ class InfeedQueue(object): policy.freeze() self._validate() - def generate_dequeue_op(self): + def generate_dequeue_op(self, tpu_device=0): """Generates the device-side Op to dequeue a tuple from the queue. Implicitly freezes the queue configuration if it is not already frozen, which will raise errors if the shapes and types have not been fully specified. + Args: + tpu_device: The TPU device ordinal where the infeed instruction should be + placed. If None, no explicit placement will be performed, and it is up + to the user to call this API from within a proper TPU device scope. + The XLA code will fail if the TPU dequeue instruction is not bound to + any device. + Returns: A list of Outputs corresponding to a shard of infeed dequeued into XLA, suitable for use within a replicated block. 
@@ -392,8 +400,13 @@ class InfeedQueue(object): policy.get_sharded_shape(shape) for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) ] - return tpu_ops.infeed_dequeue_tuple( - dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) + if tpu_device is not None: + with ops.device(tpu.core(tpu_device)): + return tpu_ops.infeed_dequeue_tuple( + dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) + else: + return tpu_ops.infeed_dequeue_tuple( + dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) def _generate_enqueue_op(self, inputs, diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 16397622ed..96eff86d8d 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -38,40 +38,60 @@ class HParamsTest(test.TestCase): self.assertFalse('bar' in hparams) def testSomeValues(self): - hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6') - self.assertDictEqual({'aaa': 1, 'b': 2.0, 'c_c': 'relu6'}, hparams.values()) - expected_str = '[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\')]' + hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d') + self.assertDictEqual( + {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'}, + hparams.values()) + expected_str = ('[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\'), ' + '(\'d\', \'/a/b=c/d\')]') self.assertEqual(expected_str, str(hparams.__str__())) self.assertEqual(expected_str, str(hparams)) self.assertEqual(1, hparams.aaa) self.assertEqual(2.0, hparams.b) self.assertEqual('relu6', hparams.c_c) + self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('aaa=12') self.assertDictEqual({ 'aaa': 12, 'b': 2.0, - 'c_c': 'relu6' + 'c_c': 'relu6', + 'd': '/a/b=c/d' }, hparams.values()) self.assertEqual(12, hparams.aaa) self.assertEqual(2.0, hparams.b) self.assertEqual('relu6', hparams.c_c) + 
self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('c_c=relu4, b=-2.0e10') self.assertDictEqual({ 'aaa': 12, 'b': -2.0e10, - 'c_c': 'relu4' + 'c_c': 'relu4', + 'd': '/a/b=c/d' }, hparams.values()) self.assertEqual(12, hparams.aaa) self.assertEqual(-2.0e10, hparams.b) self.assertEqual('relu4', hparams.c_c) + self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('c_c=,b=0,') - self.assertDictEqual({'aaa': 12, 'b': 0, 'c_c': ''}, hparams.values()) + self.assertDictEqual({'aaa': 12, 'b': 0, 'c_c': '', 'd': '/a/b=c/d'}, + hparams.values()) self.assertEqual(12, hparams.aaa) self.assertEqual(0.0, hparams.b) self.assertEqual('', hparams.c_c) + self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('c_c=2.3",b=+2,') self.assertEqual(2.0, hparams.b) self.assertEqual('2.3"', hparams.c_c) + hparams.parse('d=/a/b/c/d,aaa=11,') + self.assertEqual(11, hparams.aaa) + self.assertEqual(2.0, hparams.b) + self.assertEqual('2.3"', hparams.c_c) + self.assertEqual('/a/b/c/d', hparams.d) + hparams.parse('b=1.5,d=/a=b/c/d,aaa=10,') + self.assertEqual(10, hparams.aaa) + self.assertEqual(1.5, hparams.b) + self.assertEqual('2.3"', hparams.c_c) + self.assertEqual('/a=b/c/d', hparams.d) with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter'): hparams.parse('x=123') with self.assertRaisesRegexp(ValueError, 'Could not parse'): @@ -84,17 +104,19 @@ class HParamsTest(test.TestCase): hparams.parse('b=relu') with self.assertRaisesRegexp(ValueError, 'Must not pass a list'): hparams.parse('aaa=[123]') - self.assertEqual(12, hparams.aaa) - self.assertEqual(2.0, hparams.b) + self.assertEqual(10, hparams.aaa) + self.assertEqual(1.5, hparams.b) self.assertEqual('2.3"', hparams.c_c) + self.assertEqual('/a=b/c/d', hparams.d) # Exports to proto. hparam_def = hparams.to_proto() # Imports from proto. hparams2 = hparam.HParams(hparam_def=hparam_def) # Verifies that all hparams are restored. 
- self.assertEqual(12, hparams2.aaa) - self.assertEqual(2.0, hparams2.b) + self.assertEqual(10, hparams2.aaa) + self.assertEqual(1.5, hparams2.b) self.assertEqual('2.3"', hparams2.c_c) + self.assertEqual('/a=b/c/d', hparams2.d) def testSetFromMap(self): hparams = hparam.HParams(a=1, b=2.0, c='tanh') diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 2885a9f823..1d11410332 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -354,6 +354,7 @@ cc_library( "platform/mutex.h", "platform/net.h", "platform/notification.h", + "platform/null_file_system.h", "platform/prefetch.h", "platform/profile_utils/clock_cycle_profiler.h", "platform/profile_utils/cpu_utils.h", diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt index 9e0de08267..4eb6eb4e4d 100644 --- a/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt @@ -34,7 +34,7 @@ This operation computes Duplicate entries are handled correctly: if multiple `indices` reference the same location, their contributions add. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt> diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt new file mode 100644 index 0000000000..47148f7b03 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt @@ -0,0 +1,43 @@ +op { + graph_op_name: "ResourceScatterDiv" + in_arg { + name: "resource" + description: <<END +Should be from a `Variable` node. 
+END + } + in_arg { + name: "indices" + description: <<END +A tensor of indices into the first dimension of `ref`. +END + } + in_arg { + name: "updates" + description: <<END +A tensor of updated values to add to `ref`. +END + } + summary: "Divides sparse updates into the variable referenced by `resource`." + description: <<END +This operation computes + + # Scalar indices + ref[indices, ...] /= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] /= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions multiply. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt> +</div> +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt new file mode 100644 index 0000000000..71f06d9a43 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt @@ -0,0 +1,43 @@ +op { + graph_op_name: "ResourceScatterMax" + in_arg { + name: "resource" + description: <<END +Should be from a `Variable` node. +END + } + in_arg { + name: "indices" + description: <<END +A tensor of indices into the first dimension of `ref`. +END + } + in_arg { + name: "updates" + description: <<END +A tensor of updated values to add to `ref`. +END + } + summary: "Reduces sparse updates into the variable referenced by `resource` using the `max` operation." + description: <<END +This operation computes + + # Scalar indices + ref[indices, ...] = max(ref[indices, ...], updates[...]) + + # Vector indices (for each i) + ref[indices[i], ...] 
= max(ref[indices[i], ...], updates[i, ...]) + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions are combined. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt> +</div> +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt new file mode 100644 index 0000000000..08e40ee2a8 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt @@ -0,0 +1,43 @@ +op { + graph_op_name: "ResourceScatterMin" + in_arg { + name: "resource" + description: <<END +Should be from a `Variable` node. +END + } + in_arg { + name: "indices" + description: <<END +A tensor of indices into the first dimension of `ref`. +END + } + in_arg { + name: "updates" + description: <<END +A tensor of updated values to add to `ref`. +END + } + summary: "Reduces sparse updates into the variable referenced by `resource` using the `min` operation." + description: <<END +This operation computes + + # Scalar indices + ref[indices, ...] = min(ref[indices, ...], updates[...]) + + # Vector indices (for each i) + ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions are combined. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. 
+ +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt> +</div> +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt new file mode 100644 index 0000000000..5c63549d81 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt @@ -0,0 +1,43 @@ +op { + graph_op_name: "ResourceScatterMul" + in_arg { + name: "resource" + description: <<END +Should be from a `Variable` node. +END + } + in_arg { + name: "indices" + description: <<END +A tensor of indices into the first dimension of `ref`. +END + } + in_arg { + name: "updates" + description: <<END +A tensor of updated values to add to `ref`. +END + } + summary: "Multiplies sparse updates into the variable referenced by `resource`." + description: <<END +This operation computes + + # Scalar indices + ref[indices, ...] *= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] *= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions multiply. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. 
+ +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt> +</div> +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt new file mode 100644 index 0000000000..e71e60cbee --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt @@ -0,0 +1,43 @@ +op { + graph_op_name: "ResourceScatterSub" + in_arg { + name: "resource" + description: <<END +Should be from a `Variable` node. +END + } + in_arg { + name: "indices" + description: <<END +A tensor of indices into the first dimension of `ref`. +END + } + in_arg { + name: "updates" + description: <<END +A tensor of updated values to add to `ref`. +END + } + summary: "Subtracts sparse updates from the variable referenced by `resource`." + description: <<END +This operation computes + + # Scalar indices + ref[indices, ...] -= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] -= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions add. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt> +</div> +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt index 4b5201f025..9da9d09ea6 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt @@ -51,7 +51,7 @@ This makes it easier to chain operations that need to use the reset value. 
Duplicate entries are handled correctly: if multiple `indices` reference the same location, their contributions add. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt> diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt index 771cf0b591..8e99718c7e 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt @@ -53,6 +53,6 @@ This makes it easier to chain operations that need to use the reset value. Duplicate entries are handled correctly: if multiple `indices` reference the same location, their contributions divide. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. END } diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt new file mode 100644 index 0000000000..7b52dad4a1 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt @@ -0,0 +1,60 @@ +op { + graph_op_name: "ScatterMax" + in_arg { + name: "ref" + description: <<END +Should be from a `Variable` node. +END + } + in_arg { + name: "indices" + description: <<END +A tensor of indices into the first dimension of `ref`. +END + } + in_arg { + name: "updates" + description: <<END +A tensor of updated values to reduce into `ref`. +END + } + out_arg { + name: "output_ref" + description: <<END += Same as `ref`. Returned as a convenience for operations that want +to use the updated values after the update is done. 
+END + } + attr { + name: "use_locking" + description: <<END +If True, the update will be protected by a lock; +otherwise the behavior is undefined, but may exhibit less contention. +END + } + summary: "Reduces sparse updates into a variable reference using the `max` operation." + description: <<END +This operation computes + + # Scalar indices + ref[indices, ...] = max(ref[indices, ...], updates[...]) + + # Vector indices (for each i) + ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + +This operation outputs `ref` after the update is done. +This makes it easier to chain operations that need to use the reset value. + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions combine. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt> +</div> +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt new file mode 100644 index 0000000000..721ac0ff35 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt @@ -0,0 +1,60 @@ +op { + graph_op_name: "ScatterMin" + in_arg { + name: "ref" + description: <<END +Should be from a `Variable` node. +END + } + in_arg { + name: "indices" + description: <<END +A tensor of indices into the first dimension of `ref`. +END + } + in_arg { + name: "updates" + description: <<END +A tensor of updated values to reduce into `ref`. +END + } + out_arg { + name: "output_ref" + description: <<END += Same as `ref`. Returned as a convenience for operations that want +to use the updated values after the update is done. 
+END + } + attr { + name: "use_locking" + description: <<END +If True, the update will be protected by a lock; +otherwise the behavior is undefined, but may exhibit less contention. +END + } + summary: "Reduces sparse updates into a variable reference using the `min` operation." + description: <<END +This operation computes + + # Scalar indices + ref[indices, ...] = min(ref[indices, ...], updates[...]) + + # Vector indices (for each i) + ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...]) + +This operation outputs `ref` after the update is done. +This makes it easier to chain operations that need to use the reset value. + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions combine. + +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. + +<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> +<img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt> +</div> +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt index a51f571b00..b9e293ba9e 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt @@ -53,6 +53,6 @@ This makes it easier to chain operations that need to use the reset value. Duplicate entries are handled correctly: if multiple `indices` reference the same location, their contributions multiply. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. 
END } diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt index c0d3a4a133..d12b3e68c2 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt @@ -51,7 +51,7 @@ This makes it easier to chain operations that need to use the reset value. Duplicate entries are handled correctly: if multiple `indices` reference the same location, their (negated) contributions add. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> <img style="width:100%" src="https://www.tensorflow.org/images/ScatterSub.png" alt> diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt index c44dbbd233..4804908afc 100644 --- a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt @@ -54,7 +54,7 @@ If values in `ref` is to be updated more than once, because there are duplicate entries in `indices`, the order at which the updates happen for each value is undefined. -Requires `updates.shape = indices.shape + ref.shape[1:]`. +Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. 
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt> diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt new file mode 100644 index 0000000000..56b5a46d10 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ResourceScatterDiv" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt new file mode 100644 index 0000000000..8119bcc6c6 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ResourceScatterMax" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt new file mode 100644 index 0000000000..d874aef3fe --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ResourceScatterMin" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt new file mode 100644 index 0000000000..365a37fa0d --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ResourceScatterMul" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt new file mode 100644 index 0000000000..72dc5bf889 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: 
"ResourceScatterSub" + visibility: HIDDEN +} diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index 6ac9319ad1..16b61315f2 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/null_file_system.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index de10b10b7e..a619cac9a4 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -55,6 +55,49 @@ tf_cuda_library( ) tf_cuda_library( + name = "tensor_handle", + srcs = [ + "tensor_handle.cc", + ], + hdrs = [ + "tensor_handle.h", + ], + visibility = ["//tensorflow:internal"], + deps = [ + ":context", + ":eager_executor", + ":kernel_and_device", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + ], +) + +tf_cuda_library( + name = "copy_to_device_node", + hdrs = [ + "copy_to_device_node.h", + ], + visibility = ["//tensorflow:internal"], + deps = [ + ":context", + ":eager_executor", + ":tensor_handle", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + ], +) + +tf_cuda_library( name = "kernel_and_device", srcs = [ "kernel_and_device.cc", diff --git 
a/tensorflow/core/common_runtime/eager/copy_to_device_node.h b/tensorflow/core/common_runtime/eager/copy_to_device_node.h new file mode 100644 index 0000000000..8a887540b0 --- /dev/null +++ b/tensorflow/core/common_runtime/eager/copy_to_device_node.h @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_COPY_TO_DEVICE_NODE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_COPY_TO_DEVICE_NODE_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class CopyToDeviceNode : public EagerNode { + public: + CopyToDeviceNode(TensorHandle* src, Device* dstd, EagerContext* ctx) + : EagerNode(ctx->NextId()), + src_(src), + dstd_(dstd), + ctx_(ctx), + dst_(new TensorHandle(id, src_->dtype, ctx)) { + src_->Ref(); + dst_->Ref(); + } + + ~CopyToDeviceNode() override { + src_->Unref(); + dst_->Unref(); + } + + Status Run() override { + TensorHandle* temp = nullptr; + TF_RETURN_IF_ERROR(src_->CopyToDevice(ctx_, dstd_, &temp)); + const Tensor* tensor = nullptr; + Device* device = nullptr; + Device* op_device = nullptr; + Status status = 
temp->TensorAndDevice(&tensor, &device, &op_device); + // `temp` is a ready handle. So the following call should return OK. + TF_DCHECK_OK(status) << status.error_message(); + DCHECK(tensor); + dst_->SetTensorAndDevice(*tensor, device, op_device); + temp->Unref(); + return Status::OK(); + } + + TensorHandle* dst() { return dst_; } + + private: + TensorHandle* src_; + Device* dstd_; + EagerContext* ctx_; + TensorHandle* dst_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_COPY_TO_DEVICE_NODE_H_ diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc new file mode 100644 index 0000000000..328cd5dd5c --- /dev/null +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -0,0 +1,178 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" + +#include <algorithm> +#include <cstddef> +#include <map> +#include <memory> +#include <queue> +#include <string> +#include <vector> + +#include "tensorflow/core/common_runtime/copy_tensor.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +bool TensorHandle::IsReady() { + if (node_id == 0) return true; + mutex_lock l(ctx_mutex_); + return ctx_ == nullptr; +} + +Status TensorHandle::WaitReady() { + if (node_id == 0) return Status::OK(); + EagerExecutor* executor = nullptr; + { + mutex_lock l(ctx_mutex_); + if (ctx_ == nullptr) return Status::OK(); + executor = ctx_->Executor(); + } + return executor->WaitFor(node_id); +} + +Status TensorHandle::Tensor(const tensorflow::Tensor** t) { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *t = &tensor_; + return Status::OK(); +} + +Status TensorHandle::Device(tensorflow::Device** d) { + TF_RETURN_IF_ERROR(WaitReady()); + 
DCHECK(IsReady()); + *d = device_; + return Status::OK(); +} + +Status TensorHandle::OpDevice(tensorflow::Device** d) { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *d = op_device_; + return Status::OK(); +} + +Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor, + tensorflow::Device** device, + tensorflow::Device** op_device) { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *tensor = &tensor_; + *device = device_; + *op_device = op_device_; + return Status::OK(); +} + +void TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor, + tensorflow::Device* device, + tensorflow::Device* op_device) { + mutex_lock l(ctx_mutex_); + DCHECK(node_id > 0 && ctx_) << "SetTensorAndDevice should be only called " + << "on non-ready handles."; + ctx_ = nullptr; + tensor_ = tensor; + device_ = device; + op_device_ = op_device; +} + +Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd, + TensorHandle** output) { + const tensorflow::Tensor* src = nullptr; + tensorflow::Device* srcd = nullptr; + // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept + // nullptr. + tensorflow::Device* src_opd = nullptr; + TF_RETURN_IF_ERROR(TensorAndDevice(&src, &srcd, &src_opd)); + if (srcd == nullptr) srcd = ctx->HostCPU(); + bool is_same_device = (srcd == dstd) || (srcd->name() == dstd->name()); + const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr; + const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr; + // both_on_cpu can be true and yet is_same_device is false, if one of src/dst + // has device type XLA_CPU, and the other CPU. + const bool both_on_cpu = src_cpu && dst_cpu; + if (is_same_device || both_on_cpu) { + dstd = dst_cpu ? 
nullptr : dstd; + *output = new tensorflow::TensorHandle(*src, dstd, dstd); + return tensorflow::Status::OK(); + } + if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT && + !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) { + return tensorflow::errors::InvalidArgument( + "Can't copy Tensor with type ", + tensorflow::DataTypeString(src->dtype()), " to device ", dstd->name(), + "."); + } + tensorflow::AllocatorAttributes attr; + if (src->dtype() == tensorflow::DT_VARIANT) { + attr.set_on_host(true); + } + tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape()); + if (src->shape().num_elements() == 0) { + dstd = dst_cpu ? nullptr : dstd; + *output = new tensorflow::TensorHandle(dst, dstd, dstd); + return tensorflow::Status::OK(); + } + tensorflow::DeviceContext* src_device_context = nullptr; + if (!src_cpu) { + src_device_context = srcd->tensorflow_gpu_device_info()->default_context; + } + tensorflow::DeviceContext* dst_device_context = nullptr; + if (!dst_cpu) { + dst_device_context = dstd->tensorflow_gpu_device_info()->default_context; + } + // TODO(ashankar): The Sync() call below may be more aggressive than + // necessary. It is based on knowledge of implementation details - that + // GPU devices are implemented using 3 streams - one for host->device copies, + // one for device->host copies and one for sending operations to the GPU. + // With that setup, Sync()ing across all 3 streams should be sufficient + // but more than necessary (since it waits for operations that might have + // nothing to do with this tensor to complete). 
+ TF_RETURN_IF_ERROR(srcd->Sync()); + tensorflow::Notification n; + tensorflow::Status status; + tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context, + srcd, dstd, tensorflow::AllocatorAttributes(), + tensorflow::AllocatorAttributes(), src, &dst, + [&status, &n](const tensorflow::Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + if (status.ok()) { + dstd = dst_cpu ? nullptr : dstd; + *output = new tensorflow::TensorHandle(dst, dstd, dstd); + } + return status; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h new file mode 100644 index 0000000000..eb69a13c06 --- /dev/null +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -0,0 +1,133 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ + +#include <algorithm> +#include <cstddef> +#include <map> +#include <memory> +#include <queue> +#include <string> +#include <vector> + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +// Associates a Tensor and a Device, used in the eager runtime. Internal version +// executor_of the TFE_TensorHandle struct and the python EagerTensor class +// (unrelated to python TensorHandle). 
+class TensorHandle : public core::RefCounted { + public: + TensorHandle(const Tensor& t, Device* d, Device* op_device) + : dtype(t.dtype()), + node_id(0), + tensor_(t), + device_(d), + op_device_(op_device), + ctx_(nullptr) {} + + TensorHandle(uint64 node_id, DataType dtype, EagerContext* ctx) + : dtype(dtype), + node_id(node_id), + tensor_(dtype), + device_(nullptr), + op_device_(nullptr), + ctx_(ctx) { + DCHECK_GT(node_id, 0); + } + + ~TensorHandle() override {} + + Status Tensor(const tensorflow::Tensor** t); + + Status Device(tensorflow::Device** d); + + Status OpDevice(tensorflow::Device** d); + + Status TensorAndDevice(const tensorflow::Tensor** tensor, + tensorflow::Device** device, + tensorflow::Device** op_device); + + // Note that this can be called at most once, and only on non-ready handles, + // and makes them ready. + void SetTensorAndDevice(const tensorflow::Tensor& tensor, + tensorflow::Device* device, + tensorflow::Device* op_device); + + Status CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd, + TensorHandle** output); + + // dtype for the handle. It must be the same as t.dtype() once the handle is + // ready. + const DataType dtype; + + private: + // If the contents of the Tensor pointed to by this handle is yet to be + // computed by a EagerNode, this function will block till that compuatation is + // done and the handle is "ready". + Status WaitReady(); + + bool IsReady(); + + // Id for the EagerNode that will compute the value pointed to by this handle. + // If the value is 0, the handle is already ready, but not vice-versa. + const uint64 node_id; + + tensorflow::Tensor tensor_; + + // TODO(ashankar): device_ == nullptr iff local CPU + // This was expedient, but perhaps worth revisiting ('device_' should always + // be a valid pointer?) + // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are + // provided with the appropriate TFE_Context. 
+ // + // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a + // TFE_TensorHandle does not outlive the TFE_Context from which it came? + tensorflow::Device* device_; + + // Device in which the op producing this tensor was executed. Equals to + // device_ for constant tensors. + tensorflow::Device* op_device_; + + mutex ctx_mutex_; + + // `ctx` is only guaranteed to be set if the handle is not "ready". This is + // typically true when the handle was produced during async execution. + // `ctx` object is not owned and should outlive this handle. + EagerContext* ctx_ GUARDED_BY(ctx_mutex_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 52b9077d8c..8473b228d3 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -185,6 +185,7 @@ class DeviceBase { virtual Allocator* GetScopedAllocator(AllocatorAttributes attr, int64 step_id) { LOG(FATAL) << "Device does not implement GetScopedAllocator()"; + return nullptr; } virtual ScopedAllocatorMgr* GetScopedAllocatorMgr() const { return nullptr; } diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index bd9608b369..601984fcfd 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") # Platform specific build config @@ -286,6 +287,7 @@ tf_cuda_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", 
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", @@ -422,7 +424,7 @@ cc_library( ]), ) -tf_cuda_cc_test( +tf_cuda_only_cc_test( name = "memory_optimizer_test", srcs = ["memory_optimizer_test.cc"], tags = ["no_cuda_on_cpu_tap"], # Do not re-enable again without actually testing. @@ -498,6 +500,7 @@ cc_library( ":constant_folding", ":custom_graph_optimizer", ":custom_graph_optimizer_registry", + ":debug_stripper", ":dependency_optimizer", ":function_optimizer", ":graph_optimizer", @@ -616,3 +619,34 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "debug_stripper", + srcs = ["debug_stripper.cc"], + hdrs = [ + "debug_stripper.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:graph_optimizer", + ], +) + +tf_cuda_cc_test( + name = "debug_stripper_test", + size = "small", + srcs = ["debug_stripper_test.cc"], + deps = [ + ":debug_stripper", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/utils:grappler_test", + ], +) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 3876486d80..792f675043 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" @@ -157,6 +158,8 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) { ArithmeticOptimizer optimizer; GraphDef output; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); // Run the optimizer twice to make sure the rewrite is idempotent. @@ -172,6 +175,10 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) { EXPECT_EQ(2, new_div.input_size()); EXPECT_EQ("c1", new_div.input(0)); EXPECT_EQ("c1", new_div.input(1)); + + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 914a9257ee..6340565bcd 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -1922,6 +1922,8 @@ TEST_F(ConstantFoldingTest, PartialFolding_Concat) { item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4", "concat5", "concat6", "concat7", "concat8", "concat9"}; + auto tensors_expected = EvaluateNodes(item.graph, {"concat0"}); + EXPECT_EQ(1, tensors_expected.size()); ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); @@ -1971,9 +1973,7 @@ TEST_F(ConstantFoldingTest, 
PartialFolding_Concat) { } } - auto tensors_expected = EvaluateNodes(item.graph, {"concat0"}); auto tensors = EvaluateNodes(output, {"concat0"}); - EXPECT_EQ(1, tensors_expected.size()); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc new file mode 100644 index 0000000000..461f1aa2fb --- /dev/null +++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/debug_stripper.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { +namespace grappler { + +Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + // TODO(haoliang): Let's remove assertions here. + *output = item.graph; + return Status::OK(); +} + +void DebugStripper::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) { + // Takes no feedback. 
+} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.h b/tensorflow/core/grappler/optimizers/debug_stripper.h new file mode 100644 index 0000000000..1fe25aa1c3 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/debug_stripper.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEBUG_STRIPPER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEBUG_STRIPPER_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// DebugStripper strips off debug-related nodes (e.g. +// Assert, CheckNumerics, Print) from the graph. 
+class DebugStripper : public GraphOptimizer { + public: + DebugStripper() {} + ~DebugStripper() override {} + + string name() const override { return "debug_stripper"; }; + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEBUG_STRIPPER_H_ diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc new file mode 100644 index 0000000000..d2cabc0798 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/debug_stripper.h" + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class DebugStripperTest : public GrapplerTest {}; + +// TODO(haoliang): Add tests for different removal operations. 
+TEST_F(DebugStripperTest, OutputEqualToInput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto c = ops::Const(s.WithOpName("c"), 0, {}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + DebugStripper optimizer; + GraphDef output; + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 6eb2bbc547..47ec16226b 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/auto_parallel.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/debug_stripper.h" #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" #include "tensorflow/core/grappler/optimizers/function_optimizer.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" @@ -84,6 +85,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer( graph_optimizer.reset( new DependencyOptimizer(cfg_.dependency_optimization())); } + if (optimizer == "debug_stripper") { + graph_optimizer.reset(new DebugStripper()); + } return graph_optimizer; } @@ -134,10 +138,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimizers.push_back(std::unique_ptr<GraphOptimizer>( new AutoParallel(cfg_.auto_parallel().num_replicas()))); } + if (cfg_.debug_stripper() == RewriterConfig::ON) { + optimizers.push_back( + std::unique_ptr<GraphOptimizer>(new DebugStripper())); + } } else { const std::set<string> available_optimizers = { - "pruning", "function", "constfold", "layout", "memory", - "autoparallel", "arithmetic", "loop", 
"dependency"}; + "pruning", "function", "constfold", "layout", + "memory", "autoparallel", "arithmetic", "loop", + "dependency", "debug_stripper"}; std::vector<string> custom_optimizer_names; for (const auto& optimizer_name : cfg_.optimizers()) { if (available_optimizers.find(optimizer_name) != @@ -238,6 +247,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) { cfg.dependency_optimization() != RewriterConfig::OFF || cfg.auto_parallel().enable() || cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT || + cfg.debug_stripper() == RewriterConfig::ON || !cfg.optimizers().empty(); } diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc index 1c15ea65b8..ee126f4955 100644 --- a/tensorflow/core/grappler/utils/grappler_test.cc +++ b/tensorflow/core/grappler/utils/grappler_test.cc @@ -36,6 +36,7 @@ GrapplerTest::GrapplerTest() { cfg->set_loop_optimization(RewriterConfig::OFF); cfg->set_function_optimization(RewriterConfig::OFF); cfg->set_layout_optimizer(RewriterConfig::OFF); + cfg->set_debug_stripper(RewriterConfig::OFF); } std::vector<Tensor> GrapplerTest::EvaluateNodes( diff --git a/tensorflow/core/kernels/immutable_constant_op_test.cc b/tensorflow/core/kernels/immutable_constant_op_test.cc index b3814331ee..b2dc16d5d7 100644 --- a/tensorflow/core/kernels/immutable_constant_op_test.cc +++ b/tensorflow/core/kernels/immutable_constant_op_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/null_file_system.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/public/session.h" diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index aecad0185f..e134e476f6 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -619,22 +619,35 @@ class ResourceScatterUpdateOp : public OpKernel { if (N > 0) { auto indices_flat = indices.flat<Index>(); auto params_flat = params->flat_outer_dims<T>(); - int64 num_updates = updates.NumElements(); - OP_REQUIRES(c, num_updates % N == 0, - errors::InvalidArgument( - "shape of indices (", indices.shape().DebugString(), - ") is not compatible with the shape of updates (", - updates.shape().DebugString(), ")")); - auto updates_flat = updates.shaped<T, 2>({N, num_updates / N}); - - functor::ScatterFunctor<Device, T, Index, op> functor; - const Index bad_i = functor(c, c->template eigen_device<Device>(), - params_flat, updates_flat, indices_flat); - OP_REQUIRES(c, bad_i < 0, - errors::InvalidArgument( - "indices", SliceDebugString(indices.shape(), bad_i), - " = ", indices_flat(bad_i), " is not in [0, ", - params->dim_size(0), ")")); + if (TensorShapeUtils::IsScalar(updates.shape())) { + const auto update = updates.scalar<T>(); + + functor::ScatterScalarFunctor<Device, T, Index, op> functor; + const Index bad_i = functor(c, c->template eigen_device<Device>(), + params_flat, update, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params->dim_size(0), ")")); + } else { + int64 num_updates = updates.NumElements(); + OP_REQUIRES(c, 
num_updates % N == 0, + errors::InvalidArgument( + "shape of indices (", indices.shape().DebugString(), + ") is not compatible with the shape of updates (", + updates.shape().DebugString(), ")")); + auto updates_flat = updates.shaped<T, 2>({N, num_updates / N}); + + functor::ScatterFunctor<Device, T, Index, op> functor; + const Index bad_i = functor(c, c->template eigen_device<Device>(), + params_flat, updates_flat, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params->dim_size(0), ")")); + } } } }; @@ -652,35 +665,51 @@ class ResourceScatterUpdateOp : public OpKernel { REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \ REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op); -// TODO(apassos) add the other types here. -#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \ +#define REGISTER_SCATTER_ARITHMETIC(type, dev) \ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \ scatter_op::UpdateOp::ADD); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub", \ + scatter_op::UpdateOp::SUB); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul", \ + scatter_op::UpdateOp::MUL); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv", \ + scatter_op::UpdateOp::DIV); \ REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \ scatter_op::UpdateOp::ASSIGN); +#define REGISTER_SCATTER_MINMAX(type, dev) \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \ + scatter_op::UpdateOp::MIN); \ + REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \ + scatter_op::UpdateOp::MAX); // Registers CPU kernels. 
-#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, CPU); +#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \ + REGISTER_SCATTER_ARITHMETIC(type, CPU); +#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU); -TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU); +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU); REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate", scatter_op::UpdateOp::ASSIGN); // Registers GPU kernels. #if GOOGLE_CUDA -#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, GPU); +#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \ + REGISTER_SCATTER_ARITHMETIC(type, GPU); +#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU); #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU); #endif // GOOGLE_CUDA -#undef REGISTER_SCATTER_ARITHEMTIC -#undef REGISTER_SCATTER_ARITHEMTIC_CPU +#undef REGISTER_SCATTER_ARITHMETIC +#undef REGISTER_SCATTER_ARITHMETIC_CPU +#undef REGISTER_SCATTER_MINMAX +#undef REGISTER_SCATTER_MINMAX_CPU #undef REGISTER_SCATTER_KERNEL #undef REGISTER_SCATTER_KERNEL_INDEX diff --git a/tensorflow/core/kernels/scatter_functor.cc b/tensorflow/core/kernels/scatter_functor.cc index 7eba82899f..cf5408123f 100644 --- a/tensorflow/core/kernels/scatter_functor.cc +++ b/tensorflow/core/kernels/scatter_functor.cc @@ -26,21 +26,30 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { // Forward declarations of the functor specializations for GPU. 
-#define DECLARE_GPU_SPECS_OP(T, Index, op) \ - template <> \ - Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \ - OpKernelContext* c, const GPUDevice& d, \ - typename TTypes<T>::Matrix params, \ - typename TTypes<T>::ConstMatrix updates, \ - typename TTypes<Index>::ConstFlat indices); \ - extern template struct ScatterFunctor<GPUDevice, T, Index, op>; +#define DECLARE_GPU_SPECS_OP(T, Index, op) \ + template <> \ + Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \ + OpKernelContext* c, const GPUDevice& d, \ + typename TTypes<T>::Matrix params, \ + typename TTypes<T>::ConstMatrix updates, \ + typename TTypes<Index>::ConstFlat indices); \ + extern template struct ScatterFunctor<GPUDevice, T, Index, op>; \ + template <> \ + Index ScatterScalarFunctor<GPUDevice, T, Index, op>::operator()( \ + OpKernelContext* c, const GPUDevice& d, \ + typename TTypes<T>::Matrix params, \ + const typename TTypes<T>::ConstScalar update, \ + typename TTypes<Index>::ConstFlat indices); \ + extern template struct ScatterScalarFunctor<GPUDevice, T, Index, op>; #define DECLARE_GPU_SPECS_INDEX(T, Index) \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \ - DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); + DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \ + DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \ + DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX); #define DECLARE_GPU_SPECS(T) \ DECLARE_GPU_SPECS_INDEX(T, int32); \ diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h index 079f15e101..52666645bf 100644 --- a/tensorflow/core/kernels/scatter_functor.h +++ b/tensorflow/core/kernels/scatter_functor.h @@ -18,6 +18,8 @@ limitations under the License. 
#include <type_traits> +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/platform/types.h" @@ -33,7 +35,7 @@ typedef Eigen::SyclDevice SYCLDevice; namespace scatter_op { -enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV }; +enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX }; namespace internal { @@ -45,6 +47,10 @@ struct Assign<scatter_op::UpdateOp::ASSIGN> { static void Run(Params p, Update u) { p = u; } + template <typename Params, typename Update> + static void RunScalar(Params p, Update u) { + p.setConstant(u); + } }; template <> struct Assign<scatter_op::UpdateOp::ADD> { @@ -52,6 +58,10 @@ struct Assign<scatter_op::UpdateOp::ADD> { static void Run(Params p, Update u) { p += u; } + template <typename Params, typename Update> + static void RunScalar(Params p, Update u) { + p = p + u; + } }; template <> struct Assign<scatter_op::UpdateOp::SUB> { @@ -59,6 +69,10 @@ struct Assign<scatter_op::UpdateOp::SUB> { static void Run(Params p, Update u) { p -= u; } + template <typename Params, typename Update> + static void RunScalar(Params p, Update u) { + p = p + static_cast<Update>(-u); + } }; template <> struct Assign<scatter_op::UpdateOp::MUL> { @@ -66,6 +80,10 @@ struct Assign<scatter_op::UpdateOp::MUL> { static void Run(Params p, Update u) { p *= u; } + template <typename Params, typename Update> + static void RunScalar(Params p, Update u) { + p = p * u; + } }; template <> struct Assign<scatter_op::UpdateOp::DIV> { @@ -73,6 +91,34 @@ struct Assign<scatter_op::UpdateOp::DIV> { static void Run(Params p, Update u) { p /= u; } + template <typename Params, typename Update> + static void RunScalar(Params p, Update u) { + p = p / u; + } +}; +template <> +struct Assign<scatter_op::UpdateOp::MIN> { + // This method requires that Params and Update are tensor types. 
+ template <typename Params, typename Update> + static void Run(Params p, Update u) { + p = p.cwiseMin(u); + } + // Same thing, but for Update being a scalar type. + template <typename Params, typename Update> + static void RunScalar(Params p, Update u) { + p = p.cwiseMin(u); + } +}; +template <> +struct Assign<scatter_op::UpdateOp::MAX> { + template <typename Params, typename Update> + static void Run(Params p, Update u) { + p = p.cwiseMax(u); + } + template <typename Params, typename Update> + static void RunScalar(Params p, Update u) { + p = p.cwiseMax(u); + } }; #ifdef TENSORFLOW_USE_SYCL @@ -117,6 +163,22 @@ struct AssignSYCL<scatter_op::UpdateOp::DIV> { p.device(d) = p / u; } }; + +template <> +struct AssignSYCL<scatter_op::UpdateOp::MIN> { + template <typename Device, typename Params, typename Update> + static void Run(Device d, Params p, Update u) { + p.device(d) = p.cwiseMin(u); + } +}; + +template <> +struct AssignSYCL<scatter_op::UpdateOp::MAX> { + template <typename Device, typename Params, typename Update> + static void Run(Device d, Params p, Update u) { + p.device(d) = p.cwiseMax(u); + } +}; #endif // TENSORFLOW_USE_SYCL } // namespace internal @@ -241,6 +303,112 @@ struct ScatterFunctorSYCL { }; #endif // TENSORFLOW_USE_SYCL +template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterScalarFunctor { + Index operator()(OpKernelContext* c, const Device& d, + typename TTypes<T>::Matrix params, + const typename TTypes<T>::ConstScalar update, + typename TTypes<Index>::ConstFlat indices); +}; + +template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterScalarFunctorBase { + Index operator()(OpKernelContext* c, const Device& d, + typename TTypes<T>::Matrix params, + const typename TTypes<T>::ConstScalar update, + typename TTypes<Index>::ConstFlat indices) { + // indices and params sizes were validated in DoCompute(). 
+ const Index N = static_cast<Index>(indices.size()); + const Index limit = static_cast<Index>(params.dimension(0)); + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. An earlier version of the + // code checked it and then grabbed it from memory a second time, which + // was a security risk since it could have changed in between. + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Broadcast update to params[index] + scatter_op::internal::Assign<op>::RunScalar( + params.template chip<0>(index), update()); + } + return -1; + } +}; + +#ifdef TENSORFLOW_USE_SYCL +template <typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterScalarFunctorBase<SYCLDevice, T, Index, op> { + Index operator()(OpKernelContext* c, const SYCLDevice& d, + typename TTypes<T>::Matrix params, + const typename TTypes<T>::ConstScalar update, + typename TTypes<Index>::ConstFlat indices) { + // indices and params sizes were validated in DoCompute(). + const Index N = static_cast<Index>(indices.size()); + const Index limit = static_cast<Index>(params.dimension(0)); + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. An earlier version of the + // code checked it and then grabbed it from memory a second time, which + // was a security risk since it could have changed in between. 
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Broadcast update to params[index] + scatter_op::internal::AssignSYCL<op>::RunScalar( + d, params.template chip<0>(index), update); + } + return -1; + } +}; +#endif // TENSORFLOW_USE_SYCL + +template <typename T, typename Index> +struct ScatterScalarFunctorBase<CPUDevice, T, Index, + scatter_op::UpdateOp::ASSIGN> { + Index operator()(OpKernelContext* c, const CPUDevice& d, + typename TTypes<T>::Matrix params, + const typename TTypes<T>::ConstScalar update, + typename TTypes<Index>::ConstFlat indices) { + // indices and params sizes were validated in DoCompute(). + const Index N = static_cast<Index>(indices.size()); + const Index limit = static_cast<Index>(params.dimension(0)); + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. An earlier version of the + // code checked it and then grabbed it from memory a second time, which + // was a security risk since it could have changed in between. + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Broadcast update to params[index] + scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar( + params.template chip<0>(index), update()); + } + return -1; + } +}; + +template <typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterScalarFunctor<CPUDevice, T, Index, op> + : ScatterScalarFunctorBase<CPUDevice, T, Index, op> {}; + +#ifdef TENSORFLOW_USE_SYCL +template <typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterScalarFunctorSYCL { + Index operator()(OpKernelContext* c, const SYCLDevice& d, + typename TTypes<T>::Matrix params, + const typename TTypes<T>::ConstScalar update, + typename TTypes<Index>::Flat indices) { + // indices and params sizes were validated in DoCompute(). 
+ const Index N = static_cast<Index>(indices.size()); + const Index limit = static_cast<Index>(params.dimension(0)); + for (Index i = 0; i < N; i++) { + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Broadcast update to params[index] + scatter_op::internal::AssignSYCL<op>::Run( + d, params.template chip<0>(index), update()); + } + return -1; + } +}; +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc index 52972997cc..59911bf0d2 100644 --- a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc @@ -23,15 +23,18 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; -#define DEFINE_GPU_SPECS_OP(T, Index, op) \ - template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; +#define DEFINE_GPU_SPECS_OP(T, Index, op) \ + template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; \ + template struct functor::ScatterScalarFunctor<GPUDevice, T, Index, op>; #define DEFINE_GPU_SPECS_INDEX(T, Index) \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \ - DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX); #define DEFINE_GPU_SPECS(T) \ DEFINE_GPU_SPECS_INDEX(T, int32); \ diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/tensorflow/core/kernels/scatter_functor_gpu.cu.h index be18658543..70809e4dcf 100644 --- a/tensorflow/core/kernels/scatter_functor_gpu.cu.h +++ 
b/tensorflow/core/kernels/scatter_functor_gpu.cu.h @@ -29,12 +29,53 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; +namespace scatter_op_gpu { + +template <typename T, scatter_op::UpdateOp op> +struct ScatterOpKernelBody; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ASSIGN> { + __device__ void operator()(T* dest, T src) const { *dest = src; } +}; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ADD> { + __device__ void operator()(T* dest, T src) const { CudaAtomicAdd(dest, src); } +}; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::SUB> { + __device__ void operator()(T* dest, T src) const { CudaAtomicSub(dest, src); } +}; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MUL> { + __device__ void operator()(T* dest, T src) const { CudaAtomicMul(dest, src); } +}; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::DIV> { + __device__ void operator()(T* dest, T src) const { CudaAtomicDiv(dest, src); } +}; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MIN> { + __device__ void operator()(T* dest, T src) const { CudaAtomicMin(dest, src); } +}; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MAX> { + __device__ void operator()(T* dest, T src) const { CudaAtomicMax(dest, src); } +}; + template <typename T, typename Index, scatter_op::UpdateOp op> __global__ void ScatterOpCustomKernel(T* params, const T* updates, const Index* indices, Index first_dim_size, Index updates_size, Index indices_size) { Index update_block = updates_size / indices_size; + ScatterOpKernelBody<T, op> body; CUDA_1D_KERNEL_LOOP(i, updates_size) { int indices_i = i / update_block; int updates_i = i; @@ -44,31 +85,33 @@ __global__ void ScatterOpCustomKernel(T* params, const T* updates, continue; } int params_i = param_first_index * update_block + (i % update_block); - switch (op) { - 
case scatter_op::UpdateOp::ASSIGN: { - params[params_i] = ldg(updates + updates_i); - break; - } - case scatter_op::UpdateOp::ADD: { - CudaAtomicAdd(params + params_i, ldg(updates + updates_i)); - break; - } - case scatter_op::UpdateOp::SUB: { - CudaAtomicSub(params + params_i, ldg(updates + updates_i)); - break; - } - case scatter_op::UpdateOp::MUL: { - CudaAtomicMul(params + params_i, ldg(updates + updates_i)); - break; - } - case scatter_op::UpdateOp::DIV: { - CudaAtomicDiv(params + params_i, ldg(updates + updates_i)); - break; - } + body(¶ms[params_i], ldg(updates + updates_i)); + } +} + +template <typename T, typename Index, scatter_op::UpdateOp op> +__global__ void ScatterScalarOpCustomKernel(T* params, const T* update, + const Index* indices, + Index first_dim_size, + Index indices_size, + Index synthesized_updates_size) { + Index update_block = synthesized_updates_size / indices_size; + ScatterOpKernelBody<T, op> body; + CUDA_1D_KERNEL_LOOP(i, synthesized_updates_size) { + int indices_i = i / update_block; + int param_first_index = indices[indices_i]; + const T update_val = *update; + if (!(param_first_index >= 0 && param_first_index < first_dim_size)) { + // Ignore indices that are out of range. + continue; } + int params_i = param_first_index * update_block + (i % update_block); + body(¶ms[params_i], update_val); } } +} // namespace scatter_op_gpu + namespace functor { // Specialization for a GPU device. 
template <typename T, typename Index, scatter_op::UpdateOp op> @@ -84,7 +127,7 @@ struct ScatterFunctor<GPUDevice, T, Index, op> { const Index indices_size = indices.size(); const Index updates_size = updates.size(); CudaLaunchConfig config = GetCudaLaunchConfig(updates_size, d); - ScatterOpCustomKernel<T, Index, op> + scatter_op_gpu::ScatterOpCustomKernel<T, Index, op> <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( params.data(), updates.data(), indices.data(), first_dim_size, updates_size, indices_size); @@ -92,6 +135,27 @@ struct ScatterFunctor<GPUDevice, T, Index, op> { } }; +template <typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterScalarFunctor<GPUDevice, T, Index, op> { + Index operator()(OpKernelContext* c, const GPUDevice& d, + typename TTypes<T>::Matrix params, + const typename TTypes<T>::ConstScalar update, + typename TTypes<Index>::ConstFlat indices) { + // TODO(b/31801742): Implement indices range check. The hardest part is + // with returning a value after the range check, as we do not want to do + // device to host memcpy during a stream. 
+ const Index first_dim_size = params.dimension(0); + const Index indices_size = indices.size(); + const Index synthesized_updates_size = indices_size * params.dimension(1); + CudaLaunchConfig config = GetCudaLaunchConfig(synthesized_updates_size, d); + scatter_op_gpu::ScatterScalarOpCustomKernel<T, Index, op> + <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( + params.data(), update.data(), indices.data(), first_dim_size, + indices_size, synthesized_updates_size); + return -1; + } +}; + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 282165349f..0fbde764d5 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -38,6 +38,7 @@ typedef Eigen::SyclDevice SYCLDevice; // Check whether updates.shape = indices.shape + params.shape[1:] static bool ValidShapes(const Tensor& params, const Tensor& updates, const Tensor& indices) { + if (updates.dims() == 0) return true; if (updates.dims() != indices.dims() + params.dims() - 1) return false; for (int d = 0; d < indices.dims(); d++) { if (updates.dim_size(d) != indices.dim_size(d)) { @@ -61,11 +62,11 @@ static void DoValidationChecking(OpKernelContext* c, const Tensor& params, params.shape().DebugString())); OP_REQUIRES( c, ValidShapes(params, updates, indices), - errors::InvalidArgument( - "Must have updates.shape = indices.shape + params.shape[1:], got ", - "updates.shape ", updates.shape().DebugString(), ", indices.shape ", - indices.shape().DebugString(), ", params.shape ", - params.shape().DebugString())); + errors::InvalidArgument("Must have updates.shape = indices.shape + " + "params.shape[1:] or updates.shape = [], got ", + "updates.shape ", updates.shape().DebugString(), + ", indices.shape ", indices.shape().DebugString(), + ", params.shape ", params.shape().DebugString())); } template <typename Device, typename T, typename Index, scatter_op::UpdateOp op> @@ 
-122,16 +123,31 @@ class ScatterUpdateOp : public OpKernel { if (N > 0) { auto indices_flat = indices.flat<Index>(); auto params_flat = params.flat_outer_dims<T>(); - auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N}); - - functor::ScatterFunctor<Device, T, Index, op> functor; - const Index bad_i = functor(c, c->template eigen_device<Device>(), - params_flat, updates_flat, indices_flat); - OP_REQUIRES( - c, bad_i < 0, - errors::InvalidArgument( - "indices", SliceDebugString(indices.shape(), bad_i), " = ", - indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")")); + + if (TensorShapeUtils::IsScalar(updates.shape()) || + IsLegacyScalar(updates.shape())) { + const auto update = updates.scalar<T>(); + functor::ScatterScalarFunctor<Device, T, Index, op> functor; + const Index bad_i = functor(c, c->template eigen_device<Device>(), + params_flat, update, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params.dim_size(0), ")")); + } else { + auto updates_flat = + updates.shaped<T, 2>({N, updates.NumElements() / N}); + + functor::ScatterFunctor<Device, T, Index, op> functor; + const Index bad_i = functor(c, c->template eigen_device<Device>(), + params_flat, updates_flat, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params.dim_size(0), ")")); + } } } }; @@ -195,16 +211,31 @@ class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel { auto indices_flat = indices_host.flat<Index>(); auto params_flat = params.flat_outer_dims<T>(); - auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N}); - - functor::ScatterFunctorSYCL<T, Index, op> functor; - const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(), - params_flat, updates_flat, indices_flat); - 
OP_REQUIRES( - c, bad_i < 0, - errors::InvalidArgument( - "indices", SliceDebugString(indices.shape(), bad_i), " = ", - indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")")); + + if (TensorShapeUtils::IsScalar(updates.shape())) { + const auto update = updates.scalar<T>(); + + functor::ScatterScalarFunctorSYCL<T, Index, op> functor; + const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(), + params_flat, update, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params.dim_size(0), ")")); + } else { + auto updates_flat = + updates.shaped<T, 2>({N, updates.NumElements() / N}); + + functor::ScatterFunctorSYCL<T, Index, op> functor; + const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(), + params_flat, updates_flat, indices_flat); + OP_REQUIRES(c, bad_i < 0, + errors::InvalidArgument( + "indices", SliceDebugString(indices.shape(), bad_i), + " = ", indices_flat(bad_i), " is not in [0, ", + params.dim_size(0), ")")); + } } } }; @@ -221,54 +252,71 @@ class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel { REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \ REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op); -#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \ +#define REGISTER_SCATTER_ARITHMETIC(type, dev) \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB); +#define REGISTER_SCATTER_MINMAX(type, dev) \ + REGISTER_SCATTER_KERNEL(type, dev, "ScatterMin", scatter_op::UpdateOp::MIN); \ + REGISTER_SCATTER_KERNEL(type, dev, "ScatterMax", scatter_op::UpdateOp::MAX); + #define REGISTER_SCATTER_UPDATE(type, dev) \ 
REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \ scatter_op::UpdateOp::ASSIGN); // Registers CPU kernels. -#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, CPU); +#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \ + REGISTER_SCATTER_ARITHMETIC(type, CPU); + +#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU); #define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU); -TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU); +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU); TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU); // Registers GPU kernels. #if GOOGLE_CUDA -#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, GPU); +#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \ + REGISTER_SCATTER_ARITHMETIC(type, GPU); + +#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU); #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU); #endif // GOOGLE_CUDA // Registers GPU kernels. 
#if TENSORFLOW_USE_SYCL -#define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \ - REGISTER_SCATTER_ARITHEMTIC(type, SYCL); +#define REGISTER_SCATTER_ARITHMETIC_SYCL(type) \ + REGISTER_SCATTER_ARITHMETIC(type, SYCL); + +#define REGISTER_SCATTER_MINMAX_SYCL(type) REGISTER_SCATTER_MINMAX(type, SYCL); #define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_SYCL); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_SYCL); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_SYCL); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL); -#undef REGISTER_SCATTER_ARITHEMTIC_SYCL +#undef REGISTER_SCATTER_ARITHMETIC_SYCL +#undef REGISTER_SCATTER_MINMAX_SYCL #undef REGISTER_SCATTER_UPDATE_SYCL #endif // TENSORFLOW_USE_SYCL -#undef REGISTER_SCATTER_ARITHEMTIC -#undef REGISTER_SCATTER_ARITHEMTIC_CPU -#undef REGISTER_SCATTER_ARITHEMTIC_GPU +#undef REGISTER_SCATTER_ARITHMETIC +#undef REGISTER_SCATTER_ARITHMETIC_CPU +#undef REGISTER_SCATTER_ARITHMETIC_GPU +#undef REGISTER_SCATTER_MINMAX +#undef REGISTER_SCATTER_MINMAX_CPU +#undef REGISTER_SCATTER_MINMAX_GPU #undef REGISTER_SCATTER_UPDATE #undef REGISTER_SCATTER_UPDATE_CPU #undef REGISTER_SCATTER_UPDATE_GPU diff --git a/tensorflow/core/kernels/scatter_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_op_gpu.cu.cc index 0b43704846..0df329310f 100644 --- a/tensorflow/core/kernels/scatter_op_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_op_gpu.cu.cc @@ -24,15 +24,18 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; // Instantiates functor specializations for GPU. 
-#define DEFINE_GPU_SPECS_OP(T, Index, op) \ - template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; +#define DEFINE_GPU_SPECS_OP(T, Index, op) \ + template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; \ + template struct functor::ScatterScalarFunctor<GPUDevice, T, Index, op>; #define DEFINE_GPU_SPECS_INDEX(T, Index) \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \ - DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX); #define DEFINE_GPU_SPECS(T) \ DEFINE_GPU_SPECS_INDEX(T, int32); \ diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc index 0b8645a2ae..5b3537b94c 100644 --- a/tensorflow/core/kernels/scatter_op_test.cc +++ b/tensorflow/core/kernels/scatter_op_test.cc @@ -185,7 +185,7 @@ TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) { Status s = RunOpKernel(); EXPECT_TRUE(StringPiece(s.ToString()) .contains("Must have updates.shape = indices.shape + " - "params.shape[1:], got ")) + "params.shape[1:] or updates.shape = [], got ")) << s; } @@ -202,7 +202,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { Status s = RunOpKernel(); EXPECT_TRUE(StringPiece(s.ToString()) .contains("Must have updates.shape = indices.shape + " - "params.shape[1:], got ")) + "params.shape[1:] or updates.shape = [], got ")) << s; } @@ -219,7 +219,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) { Status s = RunOpKernel(); EXPECT_TRUE(StringPiece(s.ToString()) .contains("Must have updates.shape = indices.shape + " - "params.shape[1:], got ")) + "params.shape[1:] or updates.shape = [], 
got ")) << s; } @@ -300,6 +300,20 @@ static void BM_ScatterDivInt64(int iters, int embedding_size) { BM_ScatterHelper<int64>(iters, embedding_size, "ScatterDiv"); } +static void BM_ScatterMinInt32(int iters, int embedding_size) { + BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMin"); +} +static void BM_ScatterMinInt64(int iters, int embedding_size) { + BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMin"); +} + +static void BM_ScatterMaxInt32(int iters, int embedding_size) { + BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMax"); +} +static void BM_ScatterMaxInt64(int iters, int embedding_size) { + BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMax"); +} + BENCHMARK(BM_ScatterUpdateInt32) ->Arg(1) ->Arg(10) @@ -332,5 +346,11 @@ BENCHMARK(BM_ScatterMulInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); BENCHMARK(BM_ScatterDivInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); BENCHMARK(BM_ScatterDivInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); +BENCHMARK(BM_ScatterMinInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); +BENCHMARK(BM_ScatterMinInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); + +BENCHMARK(BM_ScatterMaxInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); +BENCHMARK(BM_ScatterMaxInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index b41826d6eb..05d6e02281 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -43706,6 +43706,210 @@ op { is_stateful: true } op { + name: "ResourceScatterDiv" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 
+ type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} +op { + name: "ResourceScatterMax" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} +op { + name: "ResourceScatterMin" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} +op { + name: "ResourceScatterMul" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + 
type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} +op { name: "ResourceScatterNdUpdate" input_arg { name: "ref" @@ -43743,6 +43947,57 @@ op { is_stateful: true } op { + name: "ResourceScatterSub" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} +op { name: "ResourceScatterUpdate" input_arg { name: "resource" @@ -48902,6 +49157,110 @@ op { } } op { + name: "ScatterMax" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + 
} + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: false + } + } +} +op { + name: "ScatterMin" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: false + } + } +} +op { name: "ScatterMul" input_arg { name: "ref" diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index af2c563489..274a7fbf75 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -21659,6 +21659,210 @@ op { is_stateful: true } op { + name: "ResourceScatterDiv" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} +op { + name: "ResourceScatterMax" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + 
name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} +op { + name: "ResourceScatterMin" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} +op { + name: "ResourceScatterMul" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tindices" + type: "type" + 
allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} +op { name: "ResourceScatterNdUpdate" input_arg { name: "ref" @@ -21696,6 +21900,57 @@ op { is_stateful: true } op { + name: "ResourceScatterSub" + input_arg { + name: "resource" + type: DT_RESOURCE + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + is_stateful: true +} +op { name: "ResourceScatterUpdate" input_arg { name: "resource" @@ -23435,6 +23690,110 @@ op { } } op { + name: "ScatterMax" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: false + } + } +} +op { + name: "ScatterMin" + input_arg { + name: "ref" + type_attr: "T" + is_ref: true + } + input_arg { + name: "indices" + type_attr: "Tindices" + } + input_arg { + name: "updates" + type_attr: "T" + } + output_arg { + name: "output_ref" + type_attr: "T" + is_ref: true + } + attr { + name: "T" + 
type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tindices" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "use_locking" + type: "bool" + default_value { + b: false + } + } +} +op { name: "ScatterMul" input_arg { name: "ref" diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index 0d8cf78cc2..3d0a6c2157 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -167,27 +167,75 @@ REGISTER_OP("ResourceGather") return Status::OK(); }); +namespace { + +Status ResourceScatterUpdateShape(InferenceContext* c) { + ShapeAndType handle_shape_and_type; + TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &handle_shape_and_type)); + ShapeHandle var_shape = handle_shape_and_type.shape; + ShapeHandle indices_shape = c->input(1); + + ShapeHandle unused_updates_shape; + ShapeHandle concat; + ShapeHandle var_subshape; + TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); + TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); + TF_RETURN_IF_ERROR( + InferenceContext::Rank(c->input(2)) == 0 + ? 
Status::OK() + : c->Merge(c->input(2), concat, &unused_updates_shape)); + return Status::OK(); +} + +} // namespace + REGISTER_OP("ResourceScatterAdd") .Input("resource: resource") .Input("indices: Tindices") .Input("updates: dtype") .Attr("dtype: numbertype") .Attr("Tindices: {int32, int64}") - .SetShapeFn([](InferenceContext* c) { - ShapeAndType handle_shape_and_type; - TF_RETURN_IF_ERROR( - ValidateVariableResourceHandle(c, &handle_shape_and_type)); - ShapeHandle var_shape = handle_shape_and_type.shape; - ShapeHandle indices_shape = c->input(1); + .SetShapeFn(ResourceScatterUpdateShape); - ShapeHandle unused_updates_shape; - ShapeHandle concat; - ShapeHandle var_subshape; - TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); - TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); - TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape)); - return Status::OK(); - }); +REGISTER_OP("ResourceScatterSub") + .Input("resource: resource") + .Input("indices: Tindices") + .Input("updates: dtype") + .Attr("dtype: numbertype") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(ResourceScatterUpdateShape); + +REGISTER_OP("ResourceScatterMul") + .Input("resource: resource") + .Input("indices: Tindices") + .Input("updates: dtype") + .Attr("dtype: numbertype") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(ResourceScatterUpdateShape); + +REGISTER_OP("ResourceScatterDiv") + .Input("resource: resource") + .Input("indices: Tindices") + .Input("updates: dtype") + .Attr("dtype: numbertype") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(ResourceScatterUpdateShape); + +REGISTER_OP("ResourceScatterMin") + .Input("resource: resource") + .Input("indices: Tindices") + .Input("updates: dtype") + .Attr("dtype: numbertype") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(ResourceScatterUpdateShape); + +REGISTER_OP("ResourceScatterMax") + .Input("resource: resource") + .Input("indices: Tindices") + .Input("updates: dtype") + .Attr("dtype: 
numbertype") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(ResourceScatterUpdateShape); REGISTER_OP("ResourceScatterUpdate") .Input("resource: resource") @@ -195,21 +243,7 @@ REGISTER_OP("ResourceScatterUpdate") .Input("updates: dtype") .Attr("dtype: type") .Attr("Tindices: {int32, int64}") - .SetShapeFn([](InferenceContext* c) { - ShapeAndType handle_shape_and_type; - TF_RETURN_IF_ERROR( - ValidateVariableResourceHandle(c, &handle_shape_and_type)); - ShapeHandle var_shape = handle_shape_and_type.shape; - ShapeHandle indices_shape = c->input(1); - - ShapeHandle unused_updates_shape; - ShapeHandle concat; - ShapeHandle var_subshape; - TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); - TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); - TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape)); - return Status::OK(); - }); + .SetShapeFn(ResourceScatterUpdateShape); REGISTER_OP("MutexV2") .Attr("container: string = ''") diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index 7a524b60c0..664f52452e 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -122,7 +122,10 @@ Status ScatterUpdateShape(InferenceContext* c) { ShapeHandle var_subshape; TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); - TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape)); + TF_RETURN_IF_ERROR( + InferenceContext::Rank(c->input(2)) == 0 + ? 
Status::OK() + : c->Merge(c->input(2), concat, &unused_updates_shape)); c->set_output(0, var_shape); return Status::OK(); @@ -180,6 +183,26 @@ REGISTER_OP("ScatterDiv") .Attr("use_locking: bool = false") .SetShapeFn(ScatterUpdateShape); +REGISTER_OP("ScatterMin") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: {half, bfloat16, float, double, int32, int64}") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .SetShapeFn(ScatterUpdateShape); + +REGISTER_OP("ScatterMax") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: {half, bfloat16, float, double, int32, int64}") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .SetShapeFn(ScatterUpdateShape); + REGISTER_OP("ScatterNdUpdate") .Input("ref: Ref(T)") .Input("indices: Tindices") diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc index 47ddf0ccb9..9a6ff48069 100644 --- a/tensorflow/core/platform/env_test.cc +++ b/tensorflow/core/platform/env_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/null_file_system.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h index 03c0c5ab51..8f99766e15 100644 --- a/tensorflow/core/platform/file_system.h +++ b/tensorflow/core/platform/file_system.h @@ -305,74 +305,6 @@ class ReadOnlyMemoryRegion { virtual uint64 length() = 0; }; -// START_SKIP_DOXYGEN - -#ifndef SWIG -// Degenerate file system that provides no implementations. 
-class NullFileSystem : public FileSystem { - public: - NullFileSystem() {} - - ~NullFileSystem() override = default; - - Status NewRandomAccessFile( - const string& fname, std::unique_ptr<RandomAccessFile>* result) override { - return errors::Unimplemented("NewRandomAccessFile unimplemented"); - } - - Status NewWritableFile(const string& fname, - std::unique_ptr<WritableFile>* result) override { - return errors::Unimplemented("NewWritableFile unimplemented"); - } - - Status NewAppendableFile(const string& fname, - std::unique_ptr<WritableFile>* result) override { - return errors::Unimplemented("NewAppendableFile unimplemented"); - } - - Status NewReadOnlyMemoryRegionFromFile( - const string& fname, - std::unique_ptr<ReadOnlyMemoryRegion>* result) override { - return errors::Unimplemented( - "NewReadOnlyMemoryRegionFromFile unimplemented"); - } - - Status FileExists(const string& fname) override { - return errors::Unimplemented("FileExists unimplemented"); - } - - Status GetChildren(const string& dir, std::vector<string>* result) override { - return errors::Unimplemented("GetChildren unimplemented"); - } - - Status DeleteFile(const string& fname) override { - return errors::Unimplemented("DeleteFile unimplemented"); - } - - Status CreateDir(const string& dirname) override { - return errors::Unimplemented("CreateDir unimplemented"); - } - - Status DeleteDir(const string& dirname) override { - return errors::Unimplemented("DeleteDir unimplemented"); - } - - Status GetFileSize(const string& fname, uint64* file_size) override { - return errors::Unimplemented("GetFileSize unimplemented"); - } - - Status RenameFile(const string& src, const string& target) override { - return errors::Unimplemented("RenameFile unimplemented"); - } - - Status Stat(const string& fname, FileStatistics* stat) override { - return errors::Unimplemented("Stat unimplemented"); - } -}; -#endif - -// END_SKIP_DOXYGEN - /// \brief A registry for file system implementations. 
/// /// Filenames are specified as an URI, which is of the form diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc index abe88ab6c7..e07aad55cb 100644 --- a/tensorflow/core/platform/file_system_test.cc +++ b/tensorflow/core/platform/file_system_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/null_file_system.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/platform/null_file_system.h b/tensorflow/core/platform/null_file_system.h new file mode 100644 index 0000000000..008e6d54d0 --- /dev/null +++ b/tensorflow/core/platform/null_file_system.h @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_NULL_FILE_SYSTEM_H_ +#define TENSORFLOW_CORE_PLATFORM_NULL_FILE_SYSTEM_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { + +// START_SKIP_DOXYGEN + +#ifndef SWIG +// Degenerate file system that provides no implementations. 
+class NullFileSystem : public FileSystem { + public: + NullFileSystem() {} + + ~NullFileSystem() override = default; + + Status NewRandomAccessFile( + const string& fname, std::unique_ptr<RandomAccessFile>* result) override { + return errors::Unimplemented("NewRandomAccessFile unimplemented"); + } + + Status NewWritableFile(const string& fname, + std::unique_ptr<WritableFile>* result) override { + return errors::Unimplemented("NewWritableFile unimplemented"); + } + + Status NewAppendableFile(const string& fname, + std::unique_ptr<WritableFile>* result) override { + return errors::Unimplemented("NewAppendableFile unimplemented"); + } + + Status NewReadOnlyMemoryRegionFromFile( + const string& fname, + std::unique_ptr<ReadOnlyMemoryRegion>* result) override { + return errors::Unimplemented( + "NewReadOnlyMemoryRegionFromFile unimplemented"); + } + + Status FileExists(const string& fname) override { + return errors::Unimplemented("FileExists unimplemented"); + } + + Status GetChildren(const string& dir, std::vector<string>* result) override { + return errors::Unimplemented("GetChildren unimplemented"); + } + + Status DeleteFile(const string& fname) override { + return errors::Unimplemented("DeleteFile unimplemented"); + } + + Status CreateDir(const string& dirname) override { + return errors::Unimplemented("CreateDir unimplemented"); + } + + Status DeleteDir(const string& dirname) override { + return errors::Unimplemented("DeleteDir unimplemented"); + } + + Status GetFileSize(const string& fname, uint64* file_size) override { + return errors::Unimplemented("GetFileSize unimplemented"); + } + + Status RenameFile(const string& src, const string& target) override { + return errors::Unimplemented("RenameFile unimplemented"); + } + + Status Stat(const string& fname, FileStatistics* stat) override { + return errors::Unimplemented("Stat unimplemented"); + } +}; +#endif + +// END_SKIP_DOXYGEN + +} // namespace tensorflow + +#endif // 
TENSORFLOW_CORE_PLATFORM_NULL_FILE_SYSTEM_H_ diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index fdf16aa1da..bb772460b0 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -46,6 +46,8 @@ message RewriterConfig { Toggle loop_optimization = 9; // Function optimizations (default is ON). Toggle function_optimization = 10; + // Strips debug-related nodes from the graph (off by default). + Toggle debug_stripper = 11; // If true, don't remove unnecessary ops from the graph bool disable_model_pruning = 2; diff --git a/tensorflow/docs_src/api_guides/python/state_ops.md b/tensorflow/docs_src/api_guides/python/state_ops.md index 0d612ee0c7..ec2d877386 100644 --- a/tensorflow/docs_src/api_guides/python/state_ops.md +++ b/tensorflow/docs_src/api_guides/python/state_ops.md @@ -83,6 +83,8 @@ automatically by the optimizers in most cases. * @{tf.scatter_sub} * @{tf.scatter_mul} * @{tf.scatter_div} +* @{tf.scatter_min} +* @{tf.scatter_max} * @{tf.scatter_nd_update} * @{tf.scatter_nd_add} * @{tf.scatter_nd_sub} diff --git a/tensorflow/docs_src/community/index.md b/tensorflow/docs_src/community/index.md index b706d9b204..ebeff8493b 100644 --- a/tensorflow/docs_src/community/index.md +++ b/tensorflow/docs_src/community/index.md @@ -13,3 +13,6 @@ This section contains the following documents: conventions that TensorFlow developers and users should follow. * @{$community/benchmarks$Benchmarks}, Benchmarks, a guide for defining and running a TensorFlow benchmark. + * @{$security$Using TensorFlow Securely}, which explains TensorFlow's security + model, a list of recent security reports, and information on how you can + report a security vulnerability to the TensorFlow team. 
diff --git a/tensorflow/docs_src/community/leftnav_files b/tensorflow/docs_src/community/leftnav_files index fab35024ad..af344506c7 100644 --- a/tensorflow/docs_src/community/leftnav_files +++ b/tensorflow/docs_src/community/leftnav_files @@ -4,3 +4,4 @@ roadmap.md documentation.md style_guide.md benchmarks.md +security.md diff --git a/tensorflow/docs_src/community/security.md b/tensorflow/docs_src/community/security.md new file mode 100644 index 0000000000..8d13c7a1ea --- /dev/null +++ b/tensorflow/docs_src/community/security.md @@ -0,0 +1,7 @@ +# Using TensorFlow Securely + +Before using TensorFlow, please take a look at our security model, list of +recent security announcements, and ways you can report security issues to the +TensorFlow team at the +[https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md](Using +TensorFlow Securely) page on GitHub. diff --git a/tensorflow/docs_src/get_started/get_started_for_beginners.md b/tensorflow/docs_src/get_started/get_started_for_beginners.md index b88483be69..f59cebe6c4 100644 --- a/tensorflow/docs_src/get_started/get_started_for_beginners.md +++ b/tensorflow/docs_src/get_started/get_started_for_beginners.md @@ -14,6 +14,11 @@ If you are already familiar with basic machine learning concepts but are new to TensorFlow, read @{$premade_estimators$Getting Started with TensorFlow: for ML Experts}. +If you'd like to learn a lot about the basics of Machine Learning, +consider taking +[Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/). + + ## The Iris classification problem Imagine you are a botanist seeking an automated way to classify each @@ -86,6 +91,9 @@ a number. Here's the representation scheme: * 1 represents versicolor * 2 represents virginica +For a look at other examples of labels and examples, see the +[ML Terminology section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/framing/ml-terminology). 
+ ## Models and training @@ -371,7 +379,7 @@ There are several categories of neural networks. We'll be using a [**fully connected neural network**](https://developers.google.com/machine-learning/glossary/#fully_connected_layer), which means that the neurons in one layer take inputs from *every* neuron in -the previous layer. For example, the following figure illustrates a +the previous layer. For example, the following figure illustrates a fully connected neural network consisting of three hidden layers: * The first hidden layer contains four neurons. @@ -385,6 +393,9 @@ fully connected neural network consisting of three hidden layers: **A neural network with three hidden layers.** <p> </p> +For a more detailed introduction to neural networks, see the +[Introduction to Neural Nets section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/introduction-to-neural-networks/anatomy). + To specify a model type, instantiate an [**Estimator**](https://developers.google.com/machine-learning/glossary/#Estimators) class. TensorFlow provides two categories of Estimators: @@ -448,9 +459,9 @@ will become very important. ### Train the model -Instantiating a `tf.Estimator.DNNClassifier` creates a framework for learning -the model. Basically, we've wired a network but haven't yet let data flow -through it. To train the neural network, call the Estimator object's `train` +Instantiating a `tf.Estimator.DNNClassifier` creates a framework for learning +the model. Basically, we've wired a network but haven't yet let data flow +through it. To train the neural network, call the Estimator object's `train` method. For example: ```python @@ -559,15 +570,15 @@ of 0.5. 
The following suggests a more effective model: <th colspan="1">Label</th> <th colspan="1">Prediction</th> </tr> - <tr> <td>5.9</td> <td>3.0</td> <td>4.3</td> <td>1.5</td> <td>1</td> + <tr> <td>5.9</td> <td>3.0</td> <td>4.3</td> <td>1.5</td> <td>1</td> <td style="background-color:green">1</td></tr> - <tr> <td>6.9</td> <td>3.1</td> <td>5.4</td> <td>2.1</td> <td>2</td> + <tr> <td>6.9</td> <td>3.1</td> <td>5.4</td> <td>2.1</td> <td>2</td> <td style="background-color:green">2</td></tr> - <tr> <td>5.1</td> <td>3.3</td> <td>1.7</td> <td>0.5</td> <td>0</td> + <tr> <td>5.1</td> <td>3.3</td> <td>1.7</td> <td>0.5</td> <td>0</td> <td style="background-color:green">0</td></tr> - <tr> <td>6.0</td> <td>3.4</td> <td>4.5</td> <td>1.6</td> <td>1</td> + <tr> <td>6.0</td> <td>3.4</td> <td>4.5</td> <td>1.6</td> <td>1</td> <td style="background-color:red">2</td></tr> - <tr> <td>5.5</td> <td>2.5</td> <td>4.0</td> <td>1.3</td> <td>1</td> + <tr> <td>5.5</td> <td>2.5</td> <td>4.0</td> <td>1.3</td> <td>1</td> <td style="background-color:green">1</td></tr> </table> @@ -631,6 +642,10 @@ Test set accuracy: 0.967 An accuracy of 0.967 implies that our trained model correctly classified 29 out of the 30 Iris species in the test set. +To get a deeper understanding of different metrics for evaluating +models, see the +[Classification section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/classification). + ### Predicting @@ -723,7 +738,6 @@ Prediction is "Virginica" (97.9%), expected "Virginica" ## Summary -<!--TODO(barryr): When MLCC is released, add pointers to relevant sections.--> This document provides a short introduction to machine learning. 
Because `premade_estimators.py` relies on high-level APIs, much of the diff --git a/tensorflow/docs_src/get_started/index.md b/tensorflow/docs_src/get_started/index.md index b7bd1286e3..fb83a770a5 100644 --- a/tensorflow/docs_src/get_started/index.md +++ b/tensorflow/docs_src/get_started/index.md @@ -1,5 +1,12 @@ # Getting Started +If you are new to machine learning, we recommend taking the following online +course prior to diving into TensorFlow documentation: + + * [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/), + which introduces machine learning concepts and encourages experimentation + with existing TensorFlow code. + TensorFlow is a tool for machine learning. While it contains a wide range of functionality, TensorFlow is mainly designed for deep neural network models. diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md index 8e46c9ee20..27b696696d 100644 --- a/tensorflow/docs_src/install/install_linux.md +++ b/tensorflow/docs_src/install/install_linux.md @@ -506,11 +506,18 @@ TensorFlow programs: <pre>Hello, TensorFlow!</pre> -If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with TensorFlow}. - If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). +If you are new to machine learning, we recommend the following: + +* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course) +* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners} + +If you are experienced with machine learning but new to TensorFlow, see +@{$get_started/premade_estimators$Getting Started with TensorFlow}. 
+ + ## Common installation problems We are relying on Stack Overflow to document TensorFlow installation problems diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md index cb7250a16e..7060ef43da 100644 --- a/tensorflow/docs_src/install/install_mac.md +++ b/tensorflow/docs_src/install/install_mac.md @@ -400,12 +400,18 @@ writing TensorFlow programs: <pre>Hello, TensorFlow!</pre> -If you are new to TensorFlow, see -@{$get_started/premade_estimators$Getting Started with TensorFlow}. - If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). +If you are new to machine learning, we recommend the following: + +* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course) +* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners} + +If you are experienced with machine learning but new to TensorFlow, see +@{$get_started/premade_estimators$Getting Started with TensorFlow}. + + ## Common installation problems We are relying on Stack Overflow to document TensorFlow installation problems diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md index 2413bc9cfb..86add74da1 100644 --- a/tensorflow/docs_src/install/install_windows.md +++ b/tensorflow/docs_src/install/install_windows.md @@ -17,7 +17,7 @@ You must choose one of the following types of TensorFlow to install: NVIDIA® GPU, you must install this version. Note that this version of TensorFlow is typically much easier to install (typically, in 5 or 10 minutes), so even if you have an NVIDIA GPU, we recommend - installing this version first. Prebuilt binaries will use AVX instructions. + installing this version first. Prebuilt binaries will use AVX instructions. * **TensorFlow with GPU support**. TensorFlow programs typically run significantly faster on a GPU than on a CPU. 
Therefore, if your system has a NVIDIA® GPU meeting the prerequisites shown below @@ -154,13 +154,17 @@ TensorFlow programs: <pre>Hello, TensorFlow!</pre> -If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with TensorFlow}. - If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). -There is also a helpful [script](https://gist.github.com/mrry/ee5dbcfdd045fa48a27d56664411d41c) -for Windows TensorFlow installation issues. +If you are new to machine learning, we recommend the following: + +* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course) +* @{$get_started/get_started_for_beginners$Getting Started for ML Beginners} + +If you are experienced with machine learning but new to TensorFlow, see +@{$get_started/premade_estimators$Getting Started with TensorFlow}. + ## Common installation problems diff --git a/tensorflow/docs_src/programmers_guide/embedding.md b/tensorflow/docs_src/programmers_guide/embedding.md index e8027fc12b..d5703e0737 100644 --- a/tensorflow/docs_src/programmers_guide/embedding.md +++ b/tensorflow/docs_src/programmers_guide/embedding.md @@ -7,6 +7,9 @@ with the TensorBoard Embedding Projector newcomers to machine learning or TensorFlow, and the Embedding Projector how-to is for users at all levels. +An alternative tutorial on these concepts is available in the +[Embeddings section of Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/embeddings/video-lecture). 
+ [TOC] An **embedding** is a mapping from discrete objects, such as words, to vectors diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 5ddd32ed48..838f4f2301 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -1089,186 +1089,232 @@ func ExpandDims(scope *Scope, input tf.Output, axis tf.Output) (output tf.Output return op.Output(0) } -// Returns (x - y)(x - y) element-wise. +// A placeholder op that passes through `input` when its output is not fed. // -// *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Arguments: +// input: The default value to produce when `output` is not fed. +// shape: The (possibly partial) shape of the tensor. +// +// Returns A placeholder tensor that defaults to `input` if it is not fed. +func PlaceholderWithDefault(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "SquaredDifference", + Type: "PlaceholderWithDefault", Input: []tf.Input{ - x, y, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Forwards the input to the output. +// A placeholder op for a value that will be fed into the computation. // -// This operator represents the loop termination condition used by the -// "pivot" switches of a loop. +// DEPRECATED at GraphDef version 23: Placeholder now behaves the same as PlaceholderV2. +// +// N.B. This operation will fail with an error if it is executed. It is +// intended as a way to represent a value that will always be fed, and to +// provide attrs that enable the fed value to be checked at runtime. // // Arguments: -// input: A boolean scalar, representing the branch predicate of the Switch op. 
+// dtype: The type of elements in the tensor. +// shape: The shape of the tensor. The shape can be any partially-specified +// shape. To be unconstrained, pass in a shape with unknown rank. // -// Returns The same tensor as `input`. -func LoopCond(scope *Scope, input tf.Output) (output tf.Output) { +// Returns A placeholder tensor that must be replaced using the feed mechanism. +func PlaceholderV2(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} opspec := tf.OpSpec{ - Type: "LoopCond", - Input: []tf.Input{ - input, - }, + Type: "PlaceholderV2", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizedMulAttr is an optional argument to QuantizedMul. -type QuantizedMulAttr func(optionalAttr) +// PlaceholderAttr is an optional argument to Placeholder. +type PlaceholderAttr func(optionalAttr) -// QuantizedMulToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedMulToutput(value tf.DataType) QuantizedMulAttr { +// PlaceholderShape sets the optional shape attribute to value. +// +// value: (Optional) The shape of the tensor. If the shape has 0 dimensions, the +// shape is unconstrained. +// If not specified, defaults to <unknown_rank:true > +func PlaceholderShape(value tf.Shape) PlaceholderAttr { return func(m optionalAttr) { - m["Toutput"] = value + m["shape"] = value } } -// Returns x * y element-wise, working on quantized buffers. -// -// Arguments: -// +// A placeholder op for a value that will be fed into the computation. // -// min_x: The float value that the lowest quantized `x` value represents. -// max_x: The float value that the highest quantized `x` value represents. -// min_y: The float value that the lowest quantized `y` value represents. -// max_y: The float value that the highest quantized `y` value represents. +// N.B. 
This operation will fail with an error if it is executed. It is +// intended as a way to represent a value that will always be fed, and to +// provide attrs that enable the fed value to be checked at runtime. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +// Arguments: +// dtype: The type of elements in the tensor. // -// *NOTE*: `QuantizedMul` supports limited forms of broadcasting. More about -// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedMulAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { +// Returns A placeholder tensor that must be replaced using the feed mechanism. +func Placeholder(scope *Scope, dtype tf.DataType, optional ...PlaceholderAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizedMul", - Input: []tf.Input{ - x, y, min_x, max_x, min_y, max_y, - }, + Type: "Placeholder", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// QuantizedMatMulAttr is an optional argument to QuantizedMatMul. -type QuantizedMatMulAttr func(optionalAttr) - -// QuantizedMatMulToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedMatMulToutput(value tf.DataType) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["Toutput"] = value - } + return op.Output(0) } -// QuantizedMatMulTransposeA sets the optional transpose_a attribute to value. +// Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor. // -// value: If true, `a` is transposed before multiplication. 
-// If not specified, defaults to false -func QuantizedMatMulTransposeA(value bool) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["transpose_a"] = value - } -} - -// QuantizedMatMulTransposeB sets the optional transpose_b attribute to value. +// This operation folds the padded areas of `input` by `MirrorPad` according to the +// `paddings` you specify. `paddings` must be the same as `paddings` argument +// given to the corresponding `MirrorPad` op. // -// value: If true, `b` is transposed before multiplication. -// If not specified, defaults to false -func QuantizedMatMulTransposeB(value bool) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["transpose_b"] = value - } -} - -// QuantizedMatMulTactivation sets the optional Tactivation attribute to value. +// The folded size of each dimension D of the output is: // -// value: The type of output produced by activation function -// following this operation. -// If not specified, defaults to DT_QUINT8 -func QuantizedMatMulTactivation(value tf.DataType) QuantizedMatMulAttr { - return func(m optionalAttr) { - m["Tactivation"] = value +// `input.dim_size(D) - paddings(D, 0) - paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 2, 3], [4, 5, 6], [7, 8, 9]]. +// # 'paddings' is [[0, 1]], [0, 1]]. +// # 'mode' is SYMMETRIC. +// # rank of 't' is 2. +// pad(t, paddings) ==> [[ 1, 5] +// [11, 28]] +// ``` +// +// Arguments: +// input: The input tensor to be folded. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// mode: The mode used in the `MirrorPad` op. +// +// Returns The folded tensor. 
+func MirrorPadGrad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) { + if scope.Err() != nil { + return } + attrs := map[string]interface{}{"mode": mode} + opspec := tf.OpSpec{ + Type: "MirrorPadGrad", + Input: []tf.Input{ + input, paddings, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Perform a quantized matrix multiplication of `a` by the matrix `b`. +// Pads a tensor with mirrored values. // -// The inputs must be two-dimensional matrices and the inner dimension of -// `a` (after being transposed if `transpose_a` is non-zero) must match the -// outer dimension of `b` (after being transposed if `transposed_b` is -// non-zero). +// This operation pads a `input` with mirrored values according to the `paddings` +// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is +// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many values to add before the contents of `input` in that dimension, and +// `paddings[D, 1]` indicates how many values to add after the contents of `input` +// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater +// than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true +// (if false, respectively). +// +// The padded size of each dimension D of the output is: +// +// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 2, 3], [4, 5, 6]]. +// # 'paddings' is [[1, 1]], [2, 2]]. +// # 'mode' is SYMMETRIC. +// # rank of 't' is 2. +// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2] +// [2, 1, 1, 2, 3, 3, 2] +// [5, 4, 4, 5, 6, 6, 5] +// [5, 4, 4, 5, 6, 6, 5]] +// ``` // // Arguments: -// a: Must be a two-dimensional tensor. -// b: Must be a two-dimensional tensor. -// min_a: The float value that the lowest quantized `a` value represents. -// max_a: The float value that the highest quantized `a` value represents. 
-// min_b: The float value that the lowest quantized `b` value represents. -// max_b: The float value that the highest quantized `b` value represents. +// input: The input tensor to be padded. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// mode: Either `REFLECT` or `SYMMETRIC`. In reflect mode the padded regions +// do not include the borders, while in symmetric mode the padded regions +// do include the borders. For example, if `input` is `[1, 2, 3]` and `paddings` +// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and +// it is `[1, 2, 3, 3, 2]` in symmetric mode. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) { +// Returns The padded tensor. +func MirrorPad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"mode": mode} opspec := tf.OpSpec{ - Type: "QuantizedMatMul", + Type: "MirrorPad", Input: []tf.Input{ - a, b, min_a, max_a, min_b, max_b, + input, paddings, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// A placeholder op that passes through `input` when its output is not fed. +// Pads a tensor. // -// Arguments: -// input: The default value to produce when `output` is not fed. -// shape: The (possibly partial) shape of the tensor. +// This operation pads `input` according to the `paddings` and `constant_values` +// you specify. 
`paddings` is an integer tensor with shape `[Dn, 2]`, where n is +// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many padding values to add before the contents of `input` in that dimension, +// and `paddings[D, 1]` indicates how many padding values to add after the contents +// of `input` in that dimension. `constant_values` is a scalar tensor of the same +// type as `input` that indicates the value to use for padding `input`. // -// Returns A placeholder tensor that defaults to `input` if it is not fed. -func PlaceholderWithDefault(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) { +// The padded size of each dimension D of the output is: +// +// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 1], [2, 2]] +// # 'paddings' is [[1, 1], [2, 2]] +// # 'constant_values' is 0 +// # rank of 't' is 2 +// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] +// [0, 0, 1, 1, 0, 0] +// [0, 0, 2, 2, 0, 0] +// [0, 0, 0, 0, 0, 0]] +// ``` +func PadV2(scope *Scope, input tf.Output, paddings tf.Output, constant_values tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "PlaceholderWithDefault", + Type: "PadV2", Input: []tf.Input{ - input, + input, paddings, constant_values, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -2063,6 +2109,47 @@ func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true i return op.Output(0), op.Output(1), op.Output(2) } +// Returns (x - y)(x - y) element-wise. +// +// *NOTE*: `SquaredDifference` supports broadcasting. 
More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SquaredDifference", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Forwards the input to the output. +// +// This operator represents the loop termination condition used by the +// "pivot" switches of a loop. +// +// Arguments: +// input: A boolean scalar, representing the branch predicate of the Switch op. +// +// Returns The same tensor as `input`. +func LoopCond(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LoopCond", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ApproximateEqualAttr is an optional argument to ApproximateEqual. type ApproximateEqualAttr func(optionalAttr) @@ -2391,50 +2478,6 @@ func Sign(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// QuantizedAddAttr is an optional argument to QuantizedAdd. -type QuantizedAddAttr func(optionalAttr) - -// QuantizedAddToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedAddToutput(value tf.DataType) QuantizedAddAttr { - return func(m optionalAttr) { - m["Toutput"] = value - } -} - -// Returns x + y element-wise, working on quantized buffers. -// -// Arguments: -// -// -// min_x: The float value that the lowest quantized `x` value represents. -// max_x: The float value that the highest quantized `x` value represents. -// min_y: The float value that the lowest quantized `y` value represents. -// max_y: The float value that the highest quantized `y` value represents. 
-// -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -// -// *NOTE*: `QuantizedAdd` supports limited forms of broadcasting. More about -// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedAddAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedAdd", - Input: []tf.Input{ - x, y, min_x, max_x, min_y, max_y, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - // ArgMinAttr is an optional argument to ArgMin. type ArgMinAttr func(optionalAttr) @@ -3741,32 +3784,6 @@ func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { return op.Output(0) } -// Given a quantized tensor described by (input, input_min, input_max), outputs a -// -// range that covers the actual values present in that tensor. This op is -// typically used to produce the requested_output_min and requested_output_max for -// Requantize. -// -// Arguments: -// -// input_min: The float value that the minimum quantized input value represents. -// input_max: The float value that the maximum quantized input value represents. -// -// Returns The computed min output.the computed max output. 
-func RequantizationRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output) (output_min tf.Output, output_max tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RequantizationRange", - Input: []tf.Input{ - input, input_min, input_max, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - // Returns the truth value of (x <= y) element-wise. // // *NOTE*: `LessEqual` supports broadcasting. More about broadcasting @@ -3943,46 +3960,6 @@ func BatchMatMul(scope *Scope, x tf.Output, y tf.Output, optional ...BatchMatMul return op.Output(0) } -// Pads a tensor. -// -// This operation pads `input` according to the `paddings` and `constant_values` -// you specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is -// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -// how many padding values to add before the contents of `input` in that dimension, -// and `paddings[D, 1]` indicates how many padding values to add after the contents -// of `input` in that dimension. `constant_values` is a scalar tensor of the same -// type as `input` that indicates the value to use for padding `input`. -// -// The padded size of each dimension D of the output is: -// -// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` -// -// For example: -// -// ``` -// # 't' is [[1, 1], [2, 2]] -// # 'paddings' is [[1, 1], [2, 2]] -// # 'constant_values' is 0 -// # rank of 't' is 2 -// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] -// [0, 0, 1, 1, 0, 0] -// [0, 0, 2, 2, 0, 0] -// [0, 0, 0, 0, 0, 0]] -// ``` -func PadV2(scope *Scope, input tf.Output, paddings tf.Output, constant_values tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "PadV2", - Input: []tf.Input{ - input, paddings, constant_values, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Returns which elements of x are NaN. 
// // @compatibility(numpy) @@ -4292,52 +4269,6 @@ func MaxPoolGradGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output return op.Output(0) } -// MaxPoolAttr is an optional argument to MaxPool. -type MaxPoolAttr func(optionalAttr) - -// MaxPoolDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolDataFormat(value string) MaxPoolAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Performs max pooling on the input. -// -// Arguments: -// input: 4-D input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns The max pooled output tensor. -func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPool", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes gradients of the maxpooling function. // // Arguments: @@ -5247,50 +5178,6 @@ func InvertPermutation(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor. -// -// This operation folds the padded areas of `input` by `MirrorPad` according to the -// `paddings` you specify. 
`paddings` must be the same as `paddings` argument -// given to the corresponding `MirrorPad` op. -// -// The folded size of each dimension D of the output is: -// -// `input.dim_size(D) - paddings(D, 0) - paddings(D, 1)` -// -// For example: -// -// ``` -// # 't' is [[1, 2, 3], [4, 5, 6], [7, 8, 9]]. -// # 'paddings' is [[0, 1]], [0, 1]]. -// # 'mode' is SYMMETRIC. -// # rank of 't' is 2. -// pad(t, paddings) ==> [[ 1, 5] -// [11, 28]] -// ``` -// -// Arguments: -// input: The input tensor to be folded. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// mode: The mode used in the `MirrorPad` op. -// -// Returns The folded tensor. -func MirrorPadGrad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"mode": mode} - opspec := tf.OpSpec{ - Type: "MirrorPadGrad", - Input: []tf.Input{ - input, paddings, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // BiasAddGradAttr is an optional argument to BiasAddGrad. type BiasAddGradAttr func(optionalAttr) @@ -5411,239 +5298,95 @@ func FusedBatchNormV2(scope *Scope, x tf.Output, scale tf.Output, offset tf.Outp return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// AvgPoolGradAttr is an optional argument to AvgPoolGrad. -type AvgPoolGradAttr func(optionalAttr) - -// AvgPoolGradDataFormat sets the optional data_format attribute to value. +// Returns the rank of a tensor. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. 
-// If not specified, defaults to "NHWC" -func AvgPoolGradDataFormat(value string) AvgPoolGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of the average pooling function. +// This operation returns an integer representing the rank of `input`. // -// Arguments: -// orig_input_shape: 1-D. Shape of the original input to `avg_pool`. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. -// the output of `avg_pool`. -// ksize: The size of the sliding window for each dimension of the input. -// strides: The stride of the sliding window for each dimension of the input. -// padding: The type of padding algorithm to use. +// For example: // -// Returns 4-D. Gradients w.r.t. the input of `avg_pool`. -func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) { +// ``` +// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] +// # shape of tensor 't' is [2, 2, 3] +// rank(t) ==> 3 +// ``` +// +// **Note**: The rank of a tensor is not the same as the rank of a matrix. The rank +// of a tensor is the number of indices required to uniquely select each element +// of the tensor. Rank is also known as "order", "degree", or "ndims." +func Rank(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "AvgPoolGrad", + Type: "Rank", Input: []tf.Input{ - orig_input_shape, grad, + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// StageClearAttr is an optional argument to StageClear. -type StageClearAttr func(optionalAttr) - -// StageClearCapacity sets the optional capacity attribute to value. 
-// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func StageClearCapacity(value int64) StageClearAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// StageClearMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func StageClearMemoryLimit(value int64) StageClearAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// StageClearContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func StageClearContainer(value string) StageClearAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// StageClearSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func StageClearSharedName(value string) StageClearAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op removes all elements in the underlying container. -// -// Returns the created operation. -func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StageClear", - - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits. -type ComputeAccidentalHitsAttr func(optionalAttr) - -// ComputeAccidentalHitsSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value. 
-// -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Computes the ids of the positions in sampled_candidates that match true_labels. -// -// When doing log-odds NCE, the result of this op should be passed through a -// SparseToDense op, then added to the logits of the sampled candidates. This has -// the effect of 'removing' the sampled labels that match the true labels by -// making the classifier sure that they are sampled labels. +// Transforms a Tensor into a serialized TensorProto proto. // // Arguments: -// true_classes: The true_classes output of UnpackSparseLabels. -// sampled_candidates: The sampled_candidates output of CandidateSampler. -// num_true: Number of true labels per context. +// tensor: A Tensor of type `T`. // -// Returns A vector of indices corresponding to rows of true_candidates.A vector of IDs of positions in sampled_candidates that match a true_label -// for the row with the corresponding index in indices.A vector of the same length as indices and ids, in which each element -// is -FLOAT_MAX. -func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) { +// Returns A serialized TensorProto proto of the input tensor. 
+func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ComputeAccidentalHits", + Type: "SerializeTensor", Input: []tf.Input{ - true_classes, sampled_candidates, + tensor, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. -type TensorArrayGatherV3Attr func(optionalAttr) +// MatrixSolveAttr is an optional argument to MatrixSolve. +type MatrixSolveAttr func(optionalAttr) -// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. +// MatrixSolveAdjoint sets the optional adjoint attribute to value. // -// value: The expected shape of an element, if known. Used to -// validate the shapes of TensorArray elements. If this shape is not -// fully specified, gathering zero-size TensorArrays is an error. -// If not specified, defaults to <unknown_rank:true > -func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { +// value: Boolean indicating whether to solve with `matrix` or its (block-wise) +// adjoint. +// If not specified, defaults to false +func MatrixSolveAdjoint(value bool) MatrixSolveAttr { return func(m optionalAttr) { - m["element_shape"] = value + m["adjoint"] = value } } -// Gather specific elements from the TensorArray into output `value`. +// Solves systems of linear equations. // -// All elements selected by `indices` must have the same shape. +// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is +// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix +// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. 
+// If `adjoint` is `True` then each output matrix satisfies +// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. // // Arguments: -// handle: The handle to a TensorArray. -// indices: The locations in the TensorArray from which to read tensor elements. -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. +// matrix: Shape is `[..., M, M]`. +// rhs: Shape is `[..., M, K]`. // -// Returns All of the elements in the TensorArray, concatenated along a new -// axis (the new dimension 0). -func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { +// Returns Shape is `[..., M, K]`. +func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorArrayGatherV3", - Input: []tf.Input{ - handle, indices, flow_in, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Converts each string in the input Tensor to its hash mod by a number of buckets. -// -// The hash function is deterministic on the content of the string within the -// process and will never change. However, it is not suitable for cryptography. -// This function may be used when CPU time is scarce and inputs are trusted or -// unimportant. There is a risk of adversaries constructing inputs that all hash -// to the same bucket. To prevent this problem, use a strong hash function with -// `tf.string_to_hash_bucket_strong`. -// -// Arguments: -// input: The strings to assign a hash bucket. -// num_buckets: The number of buckets. -// -// Returns A Tensor of the same shape as the input `string_tensor`. 
-func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_buckets": num_buckets} - opspec := tf.OpSpec{ - Type: "StringToHashBucketFast", + Type: "MatrixSolve", Input: []tf.Input{ - input, + matrix, rhs, }, Attrs: attrs, } @@ -5651,18 +5394,15 @@ func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (o return op.Output(0) } -// Returns the max of x and y (i.e. x > y ? x : y) element-wise. -// -// *NOTE*: `Maximum` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Maximum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Computes acos of x element-wise. +func Acos(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Maximum", + Type: "Acos", Input: []tf.Input{ - x, y, + x, }, } op := scope.AddOperation(opspec) @@ -5707,6 +5447,76 @@ func RFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output return op.Output(0) } +// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter. +type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr) + +// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. 
+// If not specified, defaults to "NHWC" +func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to <i:1 i:1 i:1 i:1 > +func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of depthwise convolution with respect to the filter. +// +// Arguments: +// input: 4-D with shape based on `data_format`. For example, if +// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height, +// in_width, in_channels]` tensor. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 4-D +// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor. +// out_backprop: 4-D with shape based on `data_format`. +// For example, if `data_format` is 'NHWC' then +// out_backprop shape is `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. +// the `filter` input of the convolution. 
+func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DepthwiseConv2dNativeBackpropFilter", + Input: []tf.Input{ + input, filter_sizes, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // LRNGradAttr is an optional argument to LRNGrad. type LRNGradAttr func(optionalAttr) @@ -6236,6 +6046,79 @@ func Tan(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } +// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. +type ResourceSparseApplyFtrlAttr func(optionalAttr) + +// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update relevant entries in '*var' according to the Ftrl-proximal scheme. +// +// That is for rows we have grad for, we update var, accum and linear as follows: +// accum_new = accum + grad * grad +// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. 
+// indices: A vector of indices into the first dimension of var and accum. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// lr_power: Scaling factor. Must be a scalar. +// +// Returns the created operation. +func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyFtrl", + Input: []tf.Input{ + var_, accum, linear, grad, indices, lr, l1, l2, lr_power, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Returns which elements of x are Inf. +// +// @compatibility(numpy) +// Equivalent to np.isinf +// @end_compatibility +func IsInf(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IsInf", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes the sum along sparse segments of a tensor divided by the sqrt of N. // // N is the size of the segment being reduced. @@ -6918,6 +6801,170 @@ func ResourceScatterUpdate(scope *Scope, resource tf.Output, indices tf.Output, return scope.AddOperation(opspec) } +// AvgPoolGradAttr is an optional argument to AvgPoolGrad. +type AvgPoolGradAttr func(optionalAttr) + +// AvgPoolGradDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. 
+// If not specified, defaults to "NHWC" +func AvgPoolGradDataFormat(value string) AvgPoolGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes gradients of the average pooling function. +// +// Arguments: +// orig_input_shape: 1-D. Shape of the original input to `avg_pool`. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. +// the output of `avg_pool`. +// ksize: The size of the sliding window for each dimension of the input. +// strides: The stride of the sliding window for each dimension of the input. +// padding: The type of padding algorithm to use. +// +// Returns 4-D. Gradients w.r.t. the input of `avg_pool`. +func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPoolGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "AvgPoolGrad", + Input: []tf.Input{ + orig_input_shape, grad, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StageClearAttr is an optional argument to StageClear. +type StageClearAttr func(optionalAttr) + +// StageClearCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func StageClearCapacity(value int64) StageClearAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// StageClearMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func StageClearMemoryLimit(value int64) StageClearAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// StageClearContainer sets the optional container attribute to value. 
+// If not specified, defaults to "" +func StageClearContainer(value string) StageClearAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// StageClearSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func StageClearSharedName(value string) StageClearAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes all elements in the underlying container. +// +// Returns the created operation. +func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StageClear", + + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits. +type ComputeAccidentalHitsAttr func(optionalAttr) + +// ComputeAccidentalHitsSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Computes the ids of the positions in sampled_candidates that match true_labels. +// +// When doing log-odds NCE, the result of this op should be passed through a +// SparseToDense op, then added to the logits of the sampled candidates. 
This has +// the effect of 'removing' the sampled labels that match the true labels by +// making the classifier sure that they are sampled labels. +// +// Arguments: +// true_classes: The true_classes output of UnpackSparseLabels. +// sampled_candidates: The sampled_candidates output of CandidateSampler. +// num_true: Number of true labels per context. +// +// Returns A vector of indices corresponding to rows of true_candidates.A vector of IDs of positions in sampled_candidates that match a true_label +// for the row with the corresponding index in indices.A vector of the same length as indices and ids, in which each element +// is -FLOAT_MAX. +func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_true": num_true} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ComputeAccidentalHits", + Input: []tf.Input{ + true_classes, sampled_candidates, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // CumsumAttr is an optional argument to Cumsum. type CumsumAttr func(optionalAttr) @@ -7314,79 +7361,6 @@ func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64, return op.Output(0) } -// Generates values in an interval. -// -// A sequence of `num` evenly-spaced values are generated beginning at `start`. -// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, -// so that the last one is exactly `stop`. -// -// For example: -// -// ``` -// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] -// ``` -// -// Arguments: -// start: First entry in the range. -// stop: Last entry in the range. -// num: Number of values to generate. -// -// Returns 1-D. The generated values. 
-func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LinSpace", - Input: []tf.Input{ - start, stop, num, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. -type DestroyResourceOpAttr func(optionalAttr) - -// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. -// -// value: whether to ignore the error when the resource -// doesn't exist. -// If not specified, defaults to true -func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { - return func(m optionalAttr) { - m["ignore_lookup_error"] = value - } -} - -// Deletes the resource specified by the handle. -// -// All subsequent operations using the resource will result in a NotFound -// error status. -// -// Arguments: -// resource: handle to the resource to delete. -// -// Returns the created operation. -func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DestroyResourceOp", - Input: []tf.Input{ - resource, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // Applies softmax to a batched N-D `SparseTensor`. // // The inputs represent an N-D SparseTensor with logical shape `[..., B, C]` @@ -7822,6 +7796,79 @@ func IFFT(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } +// Generates values in an interval. +// +// A sequence of `num` evenly-spaced values are generated beginning at `start`. +// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, +// so that the last one is exactly `stop`. 
+// +// For example: +// +// ``` +// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] +// ``` +// +// Arguments: +// start: First entry in the range. +// stop: Last entry in the range. +// num: Number of values to generate. +// +// Returns 1-D. The generated values. +func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LinSpace", + Input: []tf.Input{ + start, stop, num, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. +type DestroyResourceOpAttr func(optionalAttr) + +// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. +// +// value: whether to ignore the error when the resource +// doesn't exist. +// If not specified, defaults to true +func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { + return func(m optionalAttr) { + m["ignore_lookup_error"] = value + } +} + +// Deletes the resource specified by the handle. +// +// All subsequent operations using the resource will result in a NotFound +// error status. +// +// Arguments: +// resource: handle to the resource to delete. +// +// Returns the created operation. +func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DestroyResourceOp", + Input: []tf.Input{ + resource, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + // LRNAttr is an optional argument to LRN. type LRNAttr func(optionalAttr) @@ -8054,6 +8101,65 @@ func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...Resi return op.Output(0) } +// Pads a tensor with zeros. 
+// +// This operation pads a `input` with zeros according to the `paddings` you +// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the +// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many zeros to add before the contents of `input` in that dimension, and +// `paddings[D, 1]` indicates how many zeros to add after the contents of `input` +// in that dimension. +// +// The padded size of each dimension D of the output is: +// +// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 1], [2, 2]] +// # 'paddings' is [[1, 1], [2, 2]] +// # rank of 't' is 2 +// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] +// [0, 0, 1, 1, 0, 0] +// [0, 0, 2, 2, 0, 0] +// [0, 0, 0, 0, 0, 0]] +// ``` +func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Pad", + Input: []tf.Input{ + input, paddings, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Checks whether a resource handle-based variable has been initialized. +// +// Arguments: +// resource: the input resource handle. +// +// Returns a scalar boolean which is true if the variable has been +// initialized. +func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "VarIsInitializedOp", + Input: []tf.Input{ + resource, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. type StatelessRandomUniformAttr func(optionalAttr) @@ -8098,6 +8204,38 @@ func StatelessRandomUniform(scope *Scope, shape tf.Output, seed tf.Output, optio return op.Output(0) } +// Makes its input available to the next iteration. +// +// Arguments: +// data: The tensor to be made available to the next iteration. 
+// +// Returns The same tensor as `data`. +func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NextIteration", + Input: []tf.Input{ + data, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Output a fact about factorials. +func Fact(scope *Scope) (fact tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Fact", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // AngleAttr is an optional argument to Angle. type AngleAttr func(optionalAttr) @@ -8672,79 +8810,6 @@ func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output return op.Output(0) } -// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. -type ResourceSparseApplyFtrlAttr func(optionalAttr) - -// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update relevant entries in '*var' according to the Ftrl-proximal scheme. -// -// That is for rows we have grad for, we update var, accum and linear as follows: -// accum_new = accum + grad * grad -// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. 
-// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// lr_power: Scaling factor. Must be a scalar. -// -// Returns the created operation. -func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyFtrl", - Input: []tf.Input{ - var_, accum, linear, grad, indices, lr, l1, l2, lr_power, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Returns which elements of x are Inf. -// -// @compatibility(numpy) -// Equivalent to np.isinf -// @end_compatibility -func IsInf(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IsInf", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. type ResourceSparseApplyRMSPropAttr func(optionalAttr) @@ -8974,6 +9039,100 @@ func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_box return op.Output(0), op.Output(1), op.Output(2) } +// Converts each string in the input Tensor to its hash mod by a number of buckets. +// +// The hash function is deterministic on the content of the string within the +// process and will never change. However, it is not suitable for cryptography. +// This function may be used when CPU time is scarce and inputs are trusted or +// unimportant. There is a risk of adversaries constructing inputs that all hash +// to the same bucket. To prevent this problem, use a strong hash function with +// `tf.string_to_hash_bucket_strong`. 
+// +// Arguments: +// input: The strings to assign a hash bucket. +// num_buckets: The number of buckets. +// +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_buckets": num_buckets} + opspec := tf.OpSpec{ + Type: "StringToHashBucketFast", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the max of x and y (i.e. x > y ? x : y) element-wise. +// +// *NOTE*: `Maximum` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Maximum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Maximum", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. +type TensorArrayGatherV3Attr func(optionalAttr) + +// TensorArrayGatherV3ElementShape sets the optional element_shape attribute to value. +// +// value: The expected shape of an element, if known. Used to +// validate the shapes of TensorArray elements. If this shape is not +// fully specified, gathering zero-size TensorArrays is an error. +// If not specified, defaults to <unknown_rank:true > +func TensorArrayGatherV3ElementShape(value tf.Shape) TensorArrayGatherV3Attr { + return func(m optionalAttr) { + m["element_shape"] = value + } +} + +// Gather specific elements from the TensorArray into output `value`. +// +// All elements selected by `indices` must have the same shape. +// +// Arguments: +// handle: The handle to a TensorArray. +// indices: The locations in the TensorArray from which to read tensor elements. +// flow_in: A float scalar that enforces proper chaining of operations. 
+// dtype: The type of the elem that is returned. +// +// Returns All of the elements in the TensorArray, concatenated along a new +// axis (the new dimension 0). +func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV3Attr) (value tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorArrayGatherV3", + Input: []tf.Input{ + handle, indices, flow_in, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Returns x / y element-wise for integer types. // // Truncation designates that negative numbers will round fractional quantities @@ -9048,6 +9207,30 @@ func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and return tensors } +// Creates a dataset that skips `count` elements from the `input_dataset`. +// +// Arguments: +// +// count: A scalar representing the number of elements from the `input_dataset` +// that should be skipped. If count is -1, skips everything. +// +// +func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "SkipDataset", + Input: []tf.Input{ + input_dataset, count, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes the maximum along segments of a tensor. // // Read @{$math_ops#segmentation$the section on segmentation} for an explanation of @@ -9084,30 +9267,6 @@ func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf. return op.Output(0) } -// Creates a dataset that skips `count` elements from the `input_dataset`. 
-// -// Arguments: -// -// count: A scalar representing the number of elements from the `input_dataset` -// that should be skipped. If count is -1, skips everything. -// -// -func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "SkipDataset", - Input: []tf.Input{ - input_dataset, count, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes hyperbolic tangent of `x` element-wise. func Tanh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { @@ -9861,6 +10020,79 @@ func FFT(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } +// Transforms a serialized tensorflow.TensorProto proto into a Tensor. +// +// Arguments: +// serialized: A scalar string containing a serialized TensorProto proto. +// out_type: The type of the serialized tensor. The provided type must match the +// type of the serialized tensor and no implicit conversion will take place. +// +// Returns A Tensor of type `out_type`. +func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"out_type": out_type} + opspec := tf.OpSpec{ + Type: "ParseTensor", + Input: []tf.Input{ + serialized, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax. +type MaxPoolWithArgmaxAttr func(optionalAttr) + +// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value. 
+// If not specified, defaults to DT_INT64 +func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr { + return func(m optionalAttr) { + m["Targmax"] = value + } +} + +// Performs max pooling on the input and outputs both max values and indices. +// +// The indices in `argmax` are flattened, so that a maximum value at position +// `[b, y, x, c]` becomes flattened index +// `((b * height + y) * width + x) * channels + c`. +// +// The indices returned are always in `[0, height) x [0, width)` before flattening, +// even if padding is involved and the mathematically correct answer is outside +// (either negative or too large). This is a bug, but fixing it is difficult to do +// in a safe backwards compatible way, especially due to flattening. +// +// Arguments: +// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output. +func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPoolWithArgmax", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. 
type ResourceSparseApplyAdagradDAAttr func(optionalAttr) @@ -11004,73 +11236,6 @@ func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, padd return op.Output(0) } -// Transforms a Tensor into a serialized TensorProto proto. -// -// Arguments: -// tensor: A Tensor of type `T`. -// -// Returns A serialized TensorProto proto of the input tensor. -func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SerializeTensor", - Input: []tf.Input{ - tensor, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MatrixSolveAttr is an optional argument to MatrixSolve. -type MatrixSolveAttr func(optionalAttr) - -// MatrixSolveAdjoint sets the optional adjoint attribute to value. -// -// value: Boolean indicating whether to solve with `matrix` or its (block-wise) -// adjoint. -// If not specified, defaults to false -func MatrixSolveAdjoint(value bool) MatrixSolveAttr { - return func(m optionalAttr) { - m["adjoint"] = value - } -} - -// Solves systems of linear equations. -// -// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is -// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix -// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. -// If `adjoint` is `True` then each output matrix satisfies -// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. -// -// Arguments: -// matrix: Shape is `[..., M, M]`. -// rhs: Shape is `[..., M, K]`. -// -// Returns Shape is `[..., M, K]`. 
-func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MatrixSolve", - Input: []tf.Input{ - matrix, rhs, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Inverse 3D fast Fourier transform. // // Computes the inverse 3-dimensional discrete Fourier transform over the @@ -12025,6 +12190,46 @@ func AddManySparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_va return op.Output(0) } +// Concatenates tensors along one dimension. +// +// Arguments: +// values: List of `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// axis: 0-D. The dimension along which to concatenate. Must be in the +// range [-rank(values), rank(values)). +// +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ConcatV2", + Input: []tf.Input{ + tf.OutputList(values), axis, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Reads and outputs the entire contents of the input filename. +func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReadFile", + Input: []tf.Input{ + filename, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // MinAttr is an optional argument to Min. 
type MinAttr func(optionalAttr) @@ -12088,76 +12293,6 @@ func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) { return op.Output(0) } -// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter. -type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr) - -// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. -// If not specified, defaults to "NHWC" -func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to <i:1 i:1 i:1 i:1 > -func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of depthwise convolution with respect to the filter. -// -// Arguments: -// input: 4-D with shape based on `data_format`. For example, if -// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height, -// in_width, in_channels]` tensor. 
-// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 4-D -// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor. -// out_backprop: 4-D with shape based on `data_format`. -// For example, if `data_format` is 'NHWC' then -// out_backprop shape is `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. -// padding: The type of padding algorithm to use. -// -// Returns 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. -// the `filter` input of the convolution. -func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DepthwiseConv2dNativeBackpropFilter", - Input: []tf.Input{ - input, filter_sizes, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes sigmoid of `x` element-wise. // // Specifically, `y = 1 / (1 + exp(-x))`. @@ -12888,6 +13023,140 @@ func TensorArrayV2(scope *Scope, size tf.Output, dtype tf.DataType, optional ... return op.Output(0) } +// DecodeCSVAttr is an optional argument to DecodeCSV. +type DecodeCSVAttr func(optionalAttr) + +// DecodeCSVFieldDelim sets the optional field_delim attribute to value. +// +// value: char delimiter to separate fields in a record. 
+// If not specified, defaults to "," +func DecodeCSVFieldDelim(value string) DecodeCSVAttr { + return func(m optionalAttr) { + m["field_delim"] = value + } +} + +// DecodeCSVUseQuoteDelim sets the optional use_quote_delim attribute to value. +// +// value: If false, treats double quotation marks as regular +// characters inside of the string fields (ignoring RFC 4180, Section 2, +// Bullet 5). +// If not specified, defaults to true +func DecodeCSVUseQuoteDelim(value bool) DecodeCSVAttr { + return func(m optionalAttr) { + m["use_quote_delim"] = value + } +} + +// DecodeCSVNaValue sets the optional na_value attribute to value. +// +// value: Additional string to recognize as NA/NaN. +// If not specified, defaults to "" +func DecodeCSVNaValue(value string) DecodeCSVAttr { + return func(m optionalAttr) { + m["na_value"] = value + } +} + +// Convert CSV records to tensors. Each column maps to one tensor. +// +// RFC 4180 format is expected for the CSV records. +// (https://tools.ietf.org/html/rfc4180) +// Note that we allow leading and trailing spaces with int or float field. +// +// Arguments: +// records: Each string is a record/row in the csv and all records should have +// the same format. +// record_defaults: One tensor per column of the input record, with either a +// scalar default value for that column or empty if the column is required. +// +// Returns Each tensor will have the same shape as records. 
+func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeCSV", + Input: []tf.Input{ + records, tf.OutputList(record_defaults), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("DecodeCSV", err) + return + } + return output +} + +// MapClearAttr is an optional argument to MapClear. +type MapClearAttr func(optionalAttr) + +// MapClearCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapClearCapacity(value int64) MapClearAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapClearMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapClearMemoryLimit(value int64) MapClearAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapClearContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapClearContainer(value string) MapClearAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapClearSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapClearSharedName(value string) MapClearAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes all elements in the underlying container. +// +// Returns the created operation. 
+func MapClear(scope *Scope, dtypes []tf.DataType, optional ...MapClearAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MapClear", + + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + // ThreadUnsafeUnigramCandidateSamplerAttr is an optional argument to ThreadUnsafeUnigramCandidateSampler. type ThreadUnsafeUnigramCandidateSamplerAttr func(optionalAttr) @@ -13007,78 +13276,6 @@ func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output return op.Output(0) } -// Deprecated. Use TensorArrayReadV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3 -func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - opspec := tf.OpSpec{ - Type: "TensorArrayReadV2", - Input: []tf.Input{ - handle, index, flow_in, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Does nothing. Serves as a control trigger for scheduling. -// -// Only useful as a placeholder for control edges. -// -// Returns the created operation. -func ControlTrigger(scope *Scope) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ControlTrigger", - } - return scope.AddOperation(opspec) -} - -// Batch normalization. -// -// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() -// -// This op is deprecated. Prefer `tf.nn.batch_normalization`. -// -// Arguments: -// t: A 4D input Tensor. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// v: A 1D variance Tensor with size matching the last dimension of t. 
-// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// beta: A 1D beta Tensor with size matching the last dimension of t. -// An offset to be added to the normalized tensor. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this tensor will be multiplied -// with the normalized tensor. -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. -func BatchNormWithGlobalNormalization(scope *Scope, t tf.Output, m tf.Output, v tf.Output, beta tf.Output, gamma tf.Output, variance_epsilon float32, scale_after_normalization bool) (result tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} - opspec := tf.OpSpec{ - Type: "BatchNormWithGlobalNormalization", - Input: []tf.Input{ - t, m, v, beta, gamma, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // MutableDenseHashTableV2Attr is an optional argument to MutableDenseHashTableV2. type MutableDenseHashTableV2Attr func(optionalAttr) @@ -13375,65 +13572,6 @@ func TextLineDataset(scope *Scope, filenames tf.Output, compression_type tf.Outp return op.Output(0) } -// Checks whether a resource handle-based variable has been initialized. -// -// Arguments: -// resource: the input resource handle. -// -// Returns a scalar boolean which is true if the variable has been -// initialized. -func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "VarIsInitializedOp", - Input: []tf.Input{ - resource, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Pads a tensor with zeros. 
-// -// This operation pads a `input` with zeros according to the `paddings` you -// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the -// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -// how many zeros to add before the contents of `input` in that dimension, and -// `paddings[D, 1]` indicates how many zeros to add after the contents of `input` -// in that dimension. -// -// The padded size of each dimension D of the output is: -// -// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` -// -// For example: -// -// ``` -// # 't' is [[1, 1], [2, 2]] -// # 'paddings' is [[1, 1], [2, 2]] -// # rank of 't' is 2 -// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] -// [0, 0, 1, 1, 0, 0] -// [0, 0, 2, 2, 0, 0] -// [0, 0, 0, 0, 0, 0]] -// ``` -func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Pad", - Input: []tf.Input{ - input, paddings, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes gradients for SparseSegmentMean. // // Returns tensor "output" with same shape as grad, except for dimension 0 whose @@ -14876,6 +15014,101 @@ func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) { return op.Output(0) } +// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. +type MatrixSolveLsAttr func(optionalAttr) + +// MatrixSolveLsFast sets the optional fast attribute to value. +// If not specified, defaults to true +func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { + return func(m optionalAttr) { + m["fast"] = value + } +} + +// Solves one or more linear least-squares problems. +// +// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same +// type as `matrix` and shape `[..., M, K]`. 
+// The output is a tensor shape `[..., N, K]` where each output matrix solves +// each of the equations +// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]` +// in the least squares sense. +// +// We use the following notation for (complex) matrix and right-hand sides +// in the batch: +// +// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\), +// `rhs`=\\(B \in \mathbb{C}^{m \times k}\\), +// `output`=\\(X \in \mathbb{C}^{n \times k}\\), +// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\). +// +// If `fast` is `True`, then the solution is computed by solving the normal +// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then +// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares +// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + +// \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as +// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the +// minimum-norm solution to the under-determined linear system, i.e. +// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\), +// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable +// when \\(A\\) is numerically full rank and has a condition number +// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or\\(\lambda\\) is +// sufficiently large. +// +// If `fast` is `False` an algorithm based on the numerically robust complete +// orthogonal decomposition is used. This computes the minimum-norm +// least-squares solution, even when \\(A\\) is rank deficient. This path is +// typically 6-7 times slower than the fast path. If `fast` is `False` then +// `l2_regularizer` is ignored. +// +// Arguments: +// matrix: Shape is `[..., M, N]`. +// rhs: Shape is `[..., M, K]`. +// l2_regularizer: Scalar tensor. +// +// @compatibility(numpy) +// Equivalent to np.linalg.lstsq +// @end_compatibility +// +// Returns Shape is `[..., N, K]`. 
+func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MatrixSolveLs", + Input: []tf.Input{ + matrix, rhs, l2_regularizer, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Elementwise computes the bitwise OR of `x` and `y`. +// +// The result will have those bits set, that are set in `x`, `y` or both. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BitwiseOr", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // SparseToSparseSetOperationAttr is an optional argument to SparseToSparseSetOperation. type SparseToSparseSetOperationAttr func(optionalAttr) @@ -15174,6 +15407,52 @@ func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype return op.Output(0), op.Output(1), op.Output(2) } +// MaxPoolAttr is an optional argument to MaxPool. +type MaxPoolAttr func(optionalAttr) + +// MaxPoolDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolDataFormat(value string) MaxPoolAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Performs max pooling on the input. +// +// Arguments: +// input: 4-D input to pool over. 
+// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns The max pooled output tensor. +func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPool", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Says whether the targets are in the top `K` predictions. // // This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the @@ -16313,83 +16592,6 @@ func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// MfccAttr is an optional argument to Mfcc. -type MfccAttr func(optionalAttr) - -// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value. -// -// value: The highest frequency to use when calculating the -// ceptstrum. -// If not specified, defaults to 4000 -func MfccUpperFrequencyLimit(value float32) MfccAttr { - return func(m optionalAttr) { - m["upper_frequency_limit"] = value - } -} - -// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value. -// -// value: The lowest frequency to use when calculating the -// ceptstrum. -// If not specified, defaults to 20 -func MfccLowerFrequencyLimit(value float32) MfccAttr { - return func(m optionalAttr) { - m["lower_frequency_limit"] = value - } -} - -// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value. -// -// value: Resolution of the Mel bank used internally. 
-// If not specified, defaults to 40 -func MfccFilterbankChannelCount(value int64) MfccAttr { - return func(m optionalAttr) { - m["filterbank_channel_count"] = value - } -} - -// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value. -// -// value: How many output channels to produce per time slice. -// If not specified, defaults to 13 -func MfccDctCoefficientCount(value int64) MfccAttr { - return func(m optionalAttr) { - m["dct_coefficient_count"] = value - } -} - -// Transforms a spectrogram into a form that's useful for speech recognition. -// -// Mel Frequency Cepstral Coefficients are a way of representing audio data that's -// been effective as an input feature for machine learning. They are created by -// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the -// higher frequencies that are less significant to the human ear. They have a long -// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum -// is a good resource to learn more. -// -// Arguments: -// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared -// set to true. -// sample_rate: How many samples per second the source audio used. -func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Mfcc", - Input: []tf.Input{ - spectrogram, sample_rate, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Returns the element-wise sum of a list of tensors. // // `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not @@ -17022,87 +17224,129 @@ func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_uppe return op.Output(0) } -// Computes acos of x element-wise. 
-func Acos(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Acos", - Input: []tf.Input{ - x, - }, +// QuantizedMatMulAttr is an optional argument to QuantizedMatMul. +type QuantizedMatMulAttr func(optionalAttr) + +// QuantizedMatMulToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedMatMulToutput(value tf.DataType) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["Toutput"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax. -type MaxPoolWithArgmaxAttr func(optionalAttr) +// QuantizedMatMulTransposeA sets the optional transpose_a attribute to value. +// +// value: If true, `a` is transposed before multiplication. +// If not specified, defaults to false +func QuantizedMatMulTransposeA(value bool) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["transpose_a"] = value + } +} -// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value. -// If not specified, defaults to DT_INT64 -func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr { +// QuantizedMatMulTransposeB sets the optional transpose_b attribute to value. +// +// value: If true, `b` is transposed before multiplication. +// If not specified, defaults to false +func QuantizedMatMulTransposeB(value bool) QuantizedMatMulAttr { return func(m optionalAttr) { - m["Targmax"] = value + m["transpose_b"] = value } } -// Performs max pooling on the input and outputs both max values and indices. +// QuantizedMatMulTactivation sets the optional Tactivation attribute to value. // -// The indices in `argmax` are flattened, so that a maximum value at position -// `[b, y, x, c]` becomes flattened index -// `((b * height + y) * width + x) * channels + c`. +// value: The type of output produced by activation function +// following this operation. 
+// If not specified, defaults to DT_QUINT8 +func QuantizedMatMulTactivation(value tf.DataType) QuantizedMatMulAttr { + return func(m optionalAttr) { + m["Tactivation"] = value + } +} + +// Perform a quantized matrix multiplication of `a` by the matrix `b`. // -// The indices returned are always in `[0, height) x [0, width)` before flattening, -// even if padding is involved and the mathematically correct answer is outside -// (either negative or too large). This is a bug, but fixing it is difficult to do -// in a safe backwards compatible way, especially due to flattening. +// The inputs must be two-dimensional matrices and the inner dimension of +// `a` (after being transposed if `transpose_a` is non-zero) must match the +// outer dimension of `b` (after being transposed if `transposed_b` is +// non-zero). // // Arguments: -// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// a: Must be a two-dimensional tensor. +// b: Must be a two-dimensional tensor. +// min_a: The float value that the lowest quantized `a` value represents. +// max_a: The float value that the highest quantized `a` value represents. +// min_b: The float value that the lowest quantized `b` value represents. +// max_b: The float value that the highest quantized `b` value represents. // -// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output. -func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) { +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. 
+func QuantizedMatMul(scope *Scope, a tf.Output, b tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPoolWithArgmax", + Type: "QuantizedMatMul", Input: []tf.Input{ - input, + a, b, min_a, max_a, min_b, max_b, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0), op.Output(1), op.Output(2) } -// Transforms a serialized tensorflow.TensorProto proto into a Tensor. +// Does nothing. Serves as a control trigger for scheduling. // -// Arguments: -// serialized: A scalar string containing a serialized TensorProto proto. -// out_type: The type of the serialized tensor. The provided type must match the -// type of the serialized tensor and no implicit conversion will take place. +// Only useful as a placeholder for control edges. // -// Returns A Tensor of type `out_type`. -func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) { +// Returns the created operation. +func ControlTrigger(scope *Scope) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type} opspec := tf.OpSpec{ - Type: "ParseTensor", + Type: "ControlTrigger", + } + return scope.AddOperation(opspec) +} + +// Batch normalization. +// +// DEPRECATED at GraphDef version 9: Use tf.nn.batch_normalization() +// +// This op is deprecated. Prefer `tf.nn.batch_normalization`. +// +// Arguments: +// t: A 4D input Tensor. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. 
+// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// beta: A 1D beta Tensor with size matching the last dimension of t. +// An offset to be added to the normalized tensor. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this tensor will be multiplied +// with the normalized tensor. +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. +func BatchNormWithGlobalNormalization(scope *Scope, t tf.Output, m tf.Output, v tf.Output, beta tf.Output, gamma tf.Output, variance_epsilon float32, scale_after_normalization bool) (result tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} + opspec := tf.OpSpec{ + Type: "BatchNormWithGlobalNormalization", Input: []tf.Input{ - serialized, + t, m, v, beta, gamma, }, Attrs: attrs, } @@ -17110,113 +17354,95 @@ func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (outp return op.Output(0) } -// MapClearAttr is an optional argument to MapClear. -type MapClearAttr func(optionalAttr) - -// MapClearCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Deprecated. Use TensorArrayReadV3 // -// REQUIRES: value >= 0 -func MapClearCapacity(value int64) MapClearAttr { - return func(m optionalAttr) { - m["capacity"] = value +// DEPRECATED at GraphDef version 26: Use TensorArrayReadV3 +func TensorArrayReadV2(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) { + if scope.Err() != nil { + return } -} - -// MapClearMemoryLimit sets the optional memory_limit attribute to value. 
-// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapClearMemoryLimit(value int64) MapClearAttr { - return func(m optionalAttr) { - m["memory_limit"] = value + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "TensorArrayReadV2", + Input: []tf.Input{ + handle, index, flow_in, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapClearContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapClearContainer(value string) MapClearAttr { - return func(m optionalAttr) { - m["container"] = value - } -} +// QuantizedMulAttr is an optional argument to QuantizedMul. +type QuantizedMulAttr func(optionalAttr) -// MapClearSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapClearSharedName(value string) MapClearAttr { +// QuantizedMulToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedMulToutput(value tf.DataType) QuantizedMulAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["Toutput"] = value } } -// Op removes all elements in the underlying container. +// Returns x * y element-wise, working on quantized buffers. // -// Returns the created operation. -func MapClear(scope *Scope, dtypes []tf.DataType, optional ...MapClearAttr) (o *tf.Operation) { +// Arguments: +// +// +// min_x: The float value that the lowest quantized `x` value represents. +// max_x: The float value that the highest quantized `x` value represents. +// min_y: The float value that the lowest quantized `y` value represents. +// max_y: The float value that the highest quantized `y` value represents. +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +// +// *NOTE*: `QuantizedMul` supports limited forms of broadcasting. 
More about +// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func QuantizedMul(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedMulAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapClear", - + Type: "QuantizedMul", + Input: []tf.Input{ + x, y, min_x, max_x, min_y, max_y, + }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) } -// DecodeCSVAttr is an optional argument to DecodeCSV. -type DecodeCSVAttr func(optionalAttr) +// QuantizedAddAttr is an optional argument to QuantizedAdd. +type QuantizedAddAttr func(optionalAttr) -// DecodeCSVFieldDelim sets the optional field_delim attribute to value. -// -// value: char delimiter to separate fields in a record. -// If not specified, defaults to "," -func DecodeCSVFieldDelim(value string) DecodeCSVAttr { +// QuantizedAddToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedAddToutput(value tf.DataType) QuantizedAddAttr { return func(m optionalAttr) { - m["field_delim"] = value + m["Toutput"] = value } } -// DecodeCSVUseQuoteDelim sets the optional use_quote_delim attribute to value. +// Returns x + y element-wise, working on quantized buffers. // -// value: If false, treats double quotation marks as regular -// characters inside of the string fields (ignoring RFC 4180, Section 2, -// Bullet 5). -// If not specified, defaults to true -func DecodeCSVUseQuoteDelim(value bool) DecodeCSVAttr { - return func(m optionalAttr) { - m["use_quote_delim"] = value - } -} - -// DecodeCSVNaValue sets the optional na_value attribute to value. 
+// Arguments: // -// value: Additional string to recognize as NA/NaN. -// If not specified, defaults to "" -func DecodeCSVNaValue(value string) DecodeCSVAttr { - return func(m optionalAttr) { - m["na_value"] = value - } -} - -// Convert CSV records to tensors. Each column maps to one tensor. // -// RFC 4180 format is expected for the CSV records. -// (https://tools.ietf.org/html/rfc4180) -// Note that we allow leading and trailing spaces with int or float field. +// min_x: The float value that the lowest quantized `x` value represents. +// max_x: The float value that the highest quantized `x` value represents. +// min_y: The float value that the lowest quantized `y` value represents. +// max_y: The float value that the highest quantized `y` value represents. // -// Arguments: -// records: Each string is a record/row in the csv and all records should have -// the same format. -// record_defaults: One tensor per column of the input record, with either a -// scalar default value for that column or empty if the column is required. +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. // -// Returns Each tensor will have the same shape as records. -func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) { +// *NOTE*: `QuantizedAdd` supports limited forms of broadcasting. 
More about +// broadcasting [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x tf.Output, min_y tf.Output, max_y tf.Output, optional ...QuantizedAddAttr) (z tf.Output, min_z tf.Output, max_z tf.Output) { if scope.Err() != nil { return } @@ -17225,84 +17451,117 @@ func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, opt a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeCSV", + Type: "QuantizedAdd", Input: []tf.Input{ - records, tf.OutputList(record_defaults), + x, y, min_x, max_x, min_y, max_y, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return + return op.Output(0), op.Output(1), op.Output(2) +} + +// MfccAttr is an optional argument to Mfcc. +type MfccAttr func(optionalAttr) + +// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value. +// +// value: The highest frequency to use when calculating the +// ceptstrum. +// If not specified, defaults to 4000 +func MfccUpperFrequencyLimit(value float32) MfccAttr { + return func(m optionalAttr) { + m["upper_frequency_limit"] = value } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("DecodeCSV", err) - return +} + +// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value. +// +// value: The lowest frequency to use when calculating the +// ceptstrum. +// If not specified, defaults to 20 +func MfccLowerFrequencyLimit(value float32) MfccAttr { + return func(m optionalAttr) { + m["lower_frequency_limit"] = value } - return output } -// Returns the rank of a tensor. +// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value. // -// This operation returns an integer representing the rank of `input`. +// value: Resolution of the Mel bank used internally. 
+// If not specified, defaults to 40 +func MfccFilterbankChannelCount(value int64) MfccAttr { + return func(m optionalAttr) { + m["filterbank_channel_count"] = value + } +} + +// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value. // -// For example: +// value: How many output channels to produce per time slice. +// If not specified, defaults to 13 +func MfccDctCoefficientCount(value int64) MfccAttr { + return func(m optionalAttr) { + m["dct_coefficient_count"] = value + } +} + +// Transforms a spectrogram into a form that's useful for speech recognition. // -// ``` -// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -// # shape of tensor 't' is [2, 2, 3] -// rank(t) ==> 3 -// ``` +// Mel Frequency Cepstral Coefficients are a way of representing audio data that's +// been effective as an input feature for machine learning. They are created by +// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the +// higher frequencies that are less significant to the human ear. They have a long +// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum +// is a good resource to learn more. // -// **Note**: The rank of a tensor is not the same as the rank of a matrix. The rank -// of a tensor is the number of indices required to uniquely select each element -// of the tensor. Rank is also known as "order", "degree", or "ndims." -func Rank(scope *Scope, input tf.Output) (output tf.Output) { +// Arguments: +// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared +// set to true. +// sample_rate: How many samples per second the source audio used. 
+func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Rank", + Type: "Mfcc", Input: []tf.Input{ - input, + spectrogram, sample_rate, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Output a fact about factorials. -func Fact(scope *Scope) (fact tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Fact", - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Makes its input available to the next iteration. +// Given a quantized tensor described by (input, input_min, input_max), outputs a +// +// range that covers the actual values present in that tensor. This op is +// typically used to produce the requested_output_min and requested_output_max for +// Requantize. // // Arguments: -// data: The tensor to be made available to the next iteration. // -// Returns The same tensor as `data`. -func NextIteration(scope *Scope, data tf.Output) (output tf.Output) { +// input_min: The float value that the minimum quantized input value represents. +// input_max: The float value that the maximum quantized input value represents. +// +// Returns The computed min output.the computed max output. +func RequantizationRange(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output) (output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "NextIteration", + Type: "RequantizationRange", Input: []tf.Input{ - data, + input, input_min, input_max, }, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } // MapPeekAttr is an optional argument to MapPeek. 
@@ -18911,101 +19170,6 @@ func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output t return op.Output(0) } -// Elementwise computes the bitwise OR of `x` and `y`. -// -// The result will have those bits set, that are set in `x`, `y` or both. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BitwiseOr", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. -type MatrixSolveLsAttr func(optionalAttr) - -// MatrixSolveLsFast sets the optional fast attribute to value. -// If not specified, defaults to true -func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { - return func(m optionalAttr) { - m["fast"] = value - } -} - -// Solves one or more linear least-squares problems. -// -// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same -// type as `matrix` and shape `[..., M, K]`. -// The output is a tensor shape `[..., N, K]` where each output matrix solves -// each of the equations -// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]` -// in the least squares sense. -// -// We use the following notation for (complex) matrix and right-hand sides -// in the batch: -// -// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\), -// `rhs`=\\(B \in \mathbb{C}^{m \times k}\\), -// `output`=\\(X \in \mathbb{C}^{n \times k}\\), -// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\). -// -// If `fast` is `True`, then the solution is computed by solving the normal -// equations using Cholesky decomposition. 
Specifically, if \\(m \ge n\\) then -// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares -// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + -// \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as -// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the -// minimum-norm solution to the under-determined linear system, i.e. -// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\), -// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable -// when \\(A\\) is numerically full rank and has a condition number -// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or\\(\lambda\\) is -// sufficiently large. -// -// If `fast` is `False` an algorithm based on the numerically robust complete -// orthogonal decomposition is used. This computes the minimum-norm -// least-squares solution, even when \\(A\\) is rank deficient. This path is -// typically 6-7 times slower than the fast path. If `fast` is `False` then -// `l2_regularizer` is ignored. -// -// Arguments: -// matrix: Shape is `[..., M, N]`. -// rhs: Shape is `[..., M, K]`. -// l2_regularizer: Scalar tensor. -// -// @compatibility(numpy) -// Equivalent to np.linalg.lstsq -// @end_compatibility -// -// Returns Shape is `[..., N, K]`. -func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MatrixSolveLs", - Input: []tf.Input{ - matrix, rhs, l2_regularizer, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // SvdAttr is an optional argument to Svd. 
type SvdAttr func(optionalAttr) @@ -20850,6 +21014,61 @@ func Iterator(scope *Scope, shared_name string, container string, output_types [ return op.Output(0) } +// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. +type CropAndResizeGradImageAttr func(optionalAttr) + +// CropAndResizeGradImageMethod sets the optional method attribute to value. +// +// value: A string specifying the interpolation method. Only 'bilinear' is +// supported for now. +// If not specified, defaults to "bilinear" +func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { + return func(m optionalAttr) { + m["method"] = value + } +} + +// Computes the gradient of the crop_and_resize op wrt the input image tensor. +// +// Arguments: +// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. +// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor +// specifies the coordinates of a box in the `box_ind[i]` image and is specified +// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of +// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the +// `[0, 1]` interval of normalized image height is mapped to +// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in +// which case the sampled crop is an up-down flipped version of the original +// image. The width dimension is treated similarly. Normalized coordinates +// outside the `[0, 1]` range are allowed, in which case we use +// `extrapolation_value` to extrapolate the input image values. +// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. +// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. +// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` +// containing the original image size. Both `image_height` and `image_width` need +// to be positive. 
+// +// +// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. +func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"T": T} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CropAndResizeGradImage", + Input: []tf.Input{ + grads, boxes, box_ind, image_size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ShuffleDatasetAttr is an optional argument to ShuffleDataset. type ShuffleDatasetAttr func(optionalAttr) @@ -21717,47 +21936,6 @@ func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, out return op.Output(0) } -// PlaceholderAttr is an optional argument to Placeholder. -type PlaceholderAttr func(optionalAttr) - -// PlaceholderShape sets the optional shape attribute to value. -// -// value: (Optional) The shape of the tensor. If the shape has 0 dimensions, the -// shape is unconstrained. -// If not specified, defaults to <unknown_rank:true > -func PlaceholderShape(value tf.Shape) PlaceholderAttr { - return func(m optionalAttr) { - m["shape"] = value - } -} - -// A placeholder op for a value that will be fed into the computation. -// -// N.B. This operation will fail with an error if it is executed. It is -// intended as a way to represent a value that will always be fed, and to -// provide attrs that enable the fed value to be checked at runtime. -// -// Arguments: -// dtype: The type of elements in the tensor. -// -// Returns A placeholder tensor that must be replaced using the feed mechanism. 
-func Placeholder(scope *Scope, dtype tf.DataType, optional ...PlaceholderAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Placeholder", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Creates a dataset that executes a SQL query and emits rows of the result set. // // Arguments: @@ -23339,101 +23517,6 @@ func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { return scope.AddOperation(opspec) } -// CropAndResizeGradImageAttr is an optional argument to CropAndResizeGradImage. -type CropAndResizeGradImageAttr func(optionalAttr) - -// CropAndResizeGradImageMethod sets the optional method attribute to value. -// -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. -// If not specified, defaults to "bilinear" -func CropAndResizeGradImageMethod(value string) CropAndResizeGradImageAttr { - return func(m optionalAttr) { - m["method"] = value - } -} - -// Computes the gradient of the crop_and_resize op wrt the input image tensor. -// -// Arguments: -// grads: A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. -// boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor -// specifies the coordinates of a box in the `box_ind[i]` image and is specified -// in normalized coordinates `[y1, x1, y2, x2]`. A normalized coordinate value of -// `y` is mapped to the image coordinate at `y * (image_height - 1)`, so as the -// `[0, 1]` interval of normalized image height is mapped to -// `[0, image_height - 1] in image height coordinates. We do allow y1 > y2, in -// which case the sampled crop is an up-down flipped version of the original -// image. The width dimension is treated similarly. 
Normalized coordinates -// outside the `[0, 1]` range are allowed, in which case we use -// `extrapolation_value` to extrapolate the input image values. -// box_ind: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, batch)`. -// The value of `box_ind[i]` specifies the image that the `i`-th box refers to. -// image_size: A 1-D tensor with value `[batch, image_height, image_width, depth]` -// containing the original image size. Both `image_height` and `image_width` need -// to be positive. -// -// -// Returns A 4-D tensor of shape `[batch, image_height, image_width, depth]`. -func CropAndResizeGradImage(scope *Scope, grads tf.Output, boxes tf.Output, box_ind tf.Output, image_size tf.Output, T tf.DataType, optional ...CropAndResizeGradImageAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"T": T} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CropAndResizeGradImage", - Input: []tf.Input{ - grads, boxes, box_ind, image_size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Reads and outputs the entire contents of the input filename. -func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReadFile", - Input: []tf.Input{ - filename, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Concatenates tensors along one dimension. -// -// Arguments: -// values: List of `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// axis: 0-D. The dimension along which to concatenate. Must be in the -// range [-rank(values), rank(values)). -// -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. 
-func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ConcatV2", - Input: []tf.Input{ - tf.OutputList(values), axis, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Forwards the value of an available tensor from `inputs` to `output`. // // `Merge` waits for at least one of the tensors in `inputs` to become available. @@ -27804,86 +27887,3 @@ func BroadcastGradientArgs(scope *Scope, s0 tf.Output, s1 tf.Output) (r0 tf.Outp op := scope.AddOperation(opspec) return op.Output(0), op.Output(1) } - -// Pads a tensor with mirrored values. -// -// This operation pads a `input` with mirrored values according to the `paddings` -// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is -// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -// how many values to add before the contents of `input` in that dimension, and -// `paddings[D, 1]` indicates how many values to add after the contents of `input` -// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater -// than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true -// (if false, respectively). -// -// The padded size of each dimension D of the output is: -// -// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` -// -// For example: -// -// ``` -// # 't' is [[1, 2, 3], [4, 5, 6]]. -// # 'paddings' is [[1, 1]], [2, 2]]. -// # 'mode' is SYMMETRIC. -// # rank of 't' is 2. -// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2] -// [2, 1, 1, 2, 3, 3, 2] -// [5, 4, 4, 5, 6, 6, 5] -// [5, 4, 4, 5, 6, 6, 5]] -// ``` -// -// Arguments: -// input: The input tensor to be padded. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// mode: Either `REFLECT` or `SYMMETRIC`. 
In reflect mode the padded regions -// do not include the borders, while in symmetric mode the padded regions -// do include the borders. For example, if `input` is `[1, 2, 3]` and `paddings` -// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and -// it is `[1, 2, 3, 3, 2]` in symmetric mode. -// -// Returns The padded tensor. -func MirrorPad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"mode": mode} - opspec := tf.OpSpec{ - Type: "MirrorPad", - Input: []tf.Input{ - input, paddings, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// A placeholder op for a value that will be fed into the computation. -// -// DEPRECATED at GraphDef version 23: Placeholder now behaves the same as PlaceholderV2. -// -// N.B. This operation will fail with an error if it is executed. It is -// intended as a way to represent a value that will always be fed, and to -// provide attrs that enable the fed value to be checked at runtime. -// -// Arguments: -// dtype: The type of elements in the tensor. -// shape: The shape of the tensor. The shape can be any partially-specified -// shape. To be unconstrained, pass in a shape with unknown rank. -// -// Returns A placeholder tensor that must be replaced using the feed mechanism. 
-func PlaceholderV2(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape} - opspec := tf.OpSpec{ - Type: "PlaceholderV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 7bd05bb6e0..94a6355e9a 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2884,9 +2884,11 @@ py_library( ":client", ":control_flow_ops", ":data_flow_ops", + ":device", ":errors", ":framework", ":framework_for_generated_wrappers", + ":framework_ops", ":gradients", ":init_ops", ":io_ops", @@ -2911,6 +2913,7 @@ py_library( ":variable_scope", ":variables", "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", "//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 701f68b8f7..55ba509065 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1013,12 +1013,13 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { TFE_TensorHandle* t = EagerTensor_Handle(tensor); tensorflow::int64 id = EagerTensor_id(tensor); const tensorflow::Tensor* tensor = nullptr; - const tensorflow::Status status = t->Tensor(&tensor); + const tensorflow::Status status = t->handle->Tensor(&tensor); if (MaybeRaiseExceptionFromStatus(status, nullptr)) { - return tensorflow::eager::TapeTensor{id, t->dtype, + return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensorflow::TensorShape({})}; } else { - return tensorflow::eager::TapeTensor{id, t->dtype, tensor->shape()}; + return tensorflow::eager::TapeTensor{id, t->handle->dtype, + tensor->shape()}; } } tensorflow::int64 id = FastTensorId(tensor); diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py 
index c9635a9c27..bb033d3495 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -887,11 +887,12 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss( Raises: ValueError: If `thresholds` contains a value outside of `(0, 1)`. ValueError: If `loss_reduction` is invalid. + TypeError: if `label_vocabulary` has invalid type. """ thresholds = tuple(thresholds) if thresholds else tuple() if label_vocabulary is not None and not isinstance(label_vocabulary, (list, tuple)): - raise ValueError( + raise TypeError( 'label_vocabulary should be a list or tuple. Given type: {}'.format( type(label_vocabulary))) diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py index 4bb9941bb7..391b17720c 100644 --- a/tensorflow/python/framework/meta_graph.py +++ b/tensorflow/python/framework/meta_graph.py @@ -737,7 +737,9 @@ def import_scoped_meta_graph(meta_graph_or_file, import_scope or "", mark_as_used=False) importer.import_graph_def( - input_graph_def, name=(import_scope or ""), input_map=input_map, + input_graph_def, + name=(import_scope or scope_to_prepend_to_names), + input_map=input_map, producer_op_list=producer_op_list) # Restores all the other collections. 
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index 21963d0bee..5d5fb037fc 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -537,6 +537,21 @@ class ScopedMetaGraphTest(test.TestCase): self.assertEqual(list(imported_variables.values())[0].name, "foo/bar/myvar:0") + def testScopedImportUnderNameScopeNoVarScope(self): + graph = ops.Graph() + with graph.as_default(): + variables.Variable(initial_value=1.0, trainable=True, name="myvar") + meta_graph_def, _ = meta_graph.export_scoped_meta_graph(graph=graph) + + graph = ops.Graph() + with graph.as_default(): + with ops.name_scope("foo"): + imported_variables = meta_graph.import_scoped_meta_graph( + meta_graph_def) + self.assertEqual(len(imported_variables), 1) + self.assertEqual(list(imported_variables.values())[0].name, + "foo/myvar:0") + def testImportsUsingSameScopeName(self): with ops.Graph().as_default(): variables.Variable(0, name="v") diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 93edaa0cf0..e579289a8d 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -5863,6 +5863,9 @@ def strip_name_scope(name, export_scope): is None. """ if export_scope: + if export_scope[-1] == "/": + export_scope = export_scope[:-1] + try: # Strips export_scope/, export_scope///, # ^export_scope/, loc:@export_scope/. @@ -5888,6 +5891,9 @@ def prepend_name_scope(name, import_scope): is None. 
""" if import_scope: + if import_scope[-1] == "/": + import_scope = import_scope[:-1] + try: str_to_replace = r"([\^]|loc:@|^)(.*)" return re.sub(str_to_replace, r"\1" + import_scope + r"/\2", diff --git a/tensorflow/python/framework/test_file_system.cc b/tensorflow/python/framework/test_file_system.cc index 094ea6f658..6e9915adbb 100644 --- a/tensorflow/python/framework/test_file_system.cc +++ b/tensorflow/python/framework/test_file_system.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/null_file_system.h" namespace tensorflow { diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 711106d2db..16033e9b8f 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -402,11 +402,10 @@ py_test( py_test( name = "convolutional_recurrent_test", - size = "medium", + size = "large", srcs = ["_impl/keras/layers/convolutional_recurrent_test.py"], shard_count = 2, srcs_version = "PY2AND3", - tags = ["noasan"], # times out b/63678675 deps = [ ":keras", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 2dc993f811..742564f9bf 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -103,6 +103,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): v = resource_variable_ops.ResourceVariable(False, name="bool_test") self.assertAllEqual(bool(v), False) + def testFetchHandle(self): + with self.test_session(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1], name="foo") + self.assertGreater(len(handle.eval()), 0) + def testAssignVariableDtypeMismatchEager(self): with context.eager_mode(): handle = 
resource_variable_ops.var_handle_op( @@ -179,6 +185,204 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterSub(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_sub(handle, [0], + constant_op.constant( + [[2]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[-1]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMul(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_mul(handle, [0], + constant_op.constant( + [[5]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[5]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterDiv(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_div(handle, [0], + constant_op.constant( + [[3]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[2]]) + + 
@test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMin(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_min(handle, [0], + constant_op.constant( + [[3]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[3]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMax(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_max(handle, [0], + constant_op.constant( + [[3]], + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[6]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterAddScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_add(handle, [0], + constant_op.constant( + 2, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[3]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterSubScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + 
resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_sub(handle, [0], + constant_op.constant( + 2, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[-1]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMulScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_mul(handle, [0], + constant_op.constant( + 5, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[5]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterDivScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_div(handle, [0], + constant_op.constant( + 3, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[2]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMinScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_min(handle, [0], + constant_op.constant( + 3, + dtype=dtypes.int32))) + read = 
resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[3]]) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def testScatterMaxScalar(self): + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[6]], + dtype=dtypes.int32))) + self.evaluate( + resource_variable_ops.resource_scatter_max(handle, [0], + constant_op.constant( + 3, + dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[6]]) + def testScatterUpdateString(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.string, shape=[1, 1]) @@ -190,6 +394,23 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]), compat.as_bytes("b")) + def testScatterUpdateStringScalar(self): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.string, shape=[1, 1]) + self.evaluate( + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [["a"]], + dtype=dtypes.string))) + self.evaluate( + resource_variable_ops.resource_scatter_update(handle, [0], + constant_op.constant( + "b", + dtype=dtypes.string))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string) + self.assertEqual( + compat.as_bytes(self.evaluate(read)[0][0]), compat.as_bytes("b")) + # TODO(alive): get this to work in Eager mode. 
def testGPU(self): with self.test_session(use_gpu=True): diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py index 7cdf11d884..c70a4ffce7 100644 --- a/tensorflow/python/kernel_tests/scatter_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_ops_test.py @@ -38,38 +38,100 @@ def _NumpyAdd(ref, indices, updates): ref[indx] += updates[i] +def _NumpyAddScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] += update + + def _NumpySub(ref, indices, updates): for i, indx in np.ndenumerate(indices): ref[indx] -= updates[i] +def _NumpySubScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] -= update + + def _NumpyMul(ref, indices, updates): for i, indx in np.ndenumerate(indices): ref[indx] *= updates[i] +def _NumpyMulScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] *= update + + def _NumpyDiv(ref, indices, updates): for i, indx in np.ndenumerate(indices): ref[indx] /= updates[i] +def _NumpyDivScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] /= update + + +def _NumpyMin(ref, indices, updates): + for i, indx in np.ndenumerate(indices): + ref[indx] = np.minimum(ref[indx], updates[i]) + + +def _NumpyMinScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] = np.minimum(ref[indx], update) + + +def _NumpyMax(ref, indices, updates): + for i, indx in np.ndenumerate(indices): + ref[indx] = np.maximum(ref[indx], updates[i]) + + +def _NumpyMaxScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] = np.maximum(ref[indx], update) + + def _NumpyUpdate(ref, indices, updates): for i, indx in np.ndenumerate(indices): ref[indx] = updates[i] +def _NumpyUpdateScalar(ref, indices, update): + for _, indx in np.ndenumerate(indices): + ref[indx] = update + + _TF_OPS_TO_NUMPY = { state_ops.scatter_update: _NumpyUpdate, state_ops.scatter_add: 
_NumpyAdd, state_ops.scatter_sub: _NumpySub, state_ops.scatter_mul: _NumpyMul, state_ops.scatter_div: _NumpyDiv, + state_ops.scatter_min: _NumpyMin, + state_ops.scatter_max: _NumpyMax, +} + +_TF_OPS_TO_NUMPY_SCALAR = { + state_ops.scatter_update: _NumpyUpdateScalar, + state_ops.scatter_add: _NumpyAddScalar, + state_ops.scatter_sub: _NumpySubScalar, + state_ops.scatter_mul: _NumpyMulScalar, + state_ops.scatter_div: _NumpyDivScalar, + state_ops.scatter_min: _NumpyMinScalar, + state_ops.scatter_max: _NumpyMaxScalar, } class ScatterTest(test.TestCase): - def _VariableRankTest(self, tf_scatter, vtype, itype, repeat_indices=False): + def _VariableRankTest(self, + tf_scatter, + vtype, + itype, + repeat_indices=False, + updates_are_scalar=False): np.random.seed(8) with self.test_session(use_gpu=True): for indices_shape in (), (2,), (3, 7), (3, 4, 7): @@ -89,8 +151,11 @@ class ScatterTest(test.TestCase): indices[np.random.randint(size // 2)]) np.random.shuffle(indices) indices = indices.reshape(indices_shape) - updates = _AsType( - np.random.randn(*(indices_shape + extra_shape)), vtype) + if updates_are_scalar: + updates = _AsType(np.random.randn(), vtype) + else: + updates = _AsType( + np.random.randn(*(indices_shape + extra_shape)), vtype) # Clips small values to avoid division by zero. 
def clip_small_values(x): @@ -101,7 +166,10 @@ class ScatterTest(test.TestCase): # Scatter via numpy new = old.copy() - np_scatter = _TF_OPS_TO_NUMPY[tf_scatter] + if updates_are_scalar: + np_scatter = _TF_OPS_TO_NUMPY_SCALAR[tf_scatter] + else: + np_scatter = _TF_OPS_TO_NUMPY[tf_scatter] np_scatter(new, indices, updates) # Scatter via tensorflow ref = variables.Variable(old) @@ -109,25 +177,35 @@ class ScatterTest(test.TestCase): tf_scatter(ref, indices, updates).eval() self.assertAllClose(ref.eval(), new) - def _VariableRankTests(self, tf_scatter, repeat_indices=False): + def _VariableRankTests(self, + tf_scatter, + repeat_indices=False, + updates_are_scalar=False): for vtype in (np.float32, np.float64): for itype in (np.int32, np.int64): - self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices) + self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices, + updates_are_scalar) def testVariableRankUpdate(self): - self._VariableRankTests(state_ops.scatter_update) + self._VariableRankTests(state_ops.scatter_update, False) def testVariableRankAdd(self): - self._VariableRankTests(state_ops.scatter_add) + self._VariableRankTests(state_ops.scatter_add, False) def testVariableRankSub(self): - self._VariableRankTests(state_ops.scatter_sub) + self._VariableRankTests(state_ops.scatter_sub, False) def testVariableRankMul(self): - self._VariableRankTests(state_ops.scatter_mul) + self._VariableRankTests(state_ops.scatter_mul, False) def testVariableRankDiv(self): - self._VariableRankTests(state_ops.scatter_div) + self._VariableRankTests(state_ops.scatter_div, False) + + def testVariableRankMin(self): + self._VariableRankTests(state_ops.scatter_min, False) + + def testVariableRankMax(self): + self._VariableRankTests(state_ops.scatter_max, False) def testRepeatIndicesAdd(self): self._VariableRankTests(state_ops.scatter_add, True) @@ -141,6 +219,51 @@ class ScatterTest(test.TestCase): def testRepeatIndicesDiv(self): self._VariableRankTests(state_ops.scatter_div, 
True) + def testRepeatIndicesMin(self): + self._VariableRankTests(state_ops.scatter_min, True) + + def testRepeatIndicesMax(self): + self._VariableRankTests(state_ops.scatter_max, True) + + def testVariableRankUpdateScalar(self): + self._VariableRankTests(state_ops.scatter_update, False, True) + + def testVariableRankAddScalar(self): + self._VariableRankTests(state_ops.scatter_add, False, True) + + def testVariableRankSubScalar(self): + self._VariableRankTests(state_ops.scatter_sub, False, True) + + def testVariableRankMulScalar(self): + self._VariableRankTests(state_ops.scatter_mul, False, True) + + def testVariableRankDivScalar(self): + self._VariableRankTests(state_ops.scatter_div, False, True) + + def testVariableRankMinScalar(self): + self._VariableRankTests(state_ops.scatter_min, False, True) + + def testVariableRankMaxScalar(self): + self._VariableRankTests(state_ops.scatter_max, False, True) + + def testRepeatIndicesAddScalar(self): + self._VariableRankTests(state_ops.scatter_add, True, True) + + def testRepeatIndicesSubScalar(self): + self._VariableRankTests(state_ops.scatter_sub, True, True) + + def testRepeatIndicesMulScalar(self): + self._VariableRankTests(state_ops.scatter_mul, True, True) + + def testRepeatIndicesDivScalar(self): + self._VariableRankTests(state_ops.scatter_div, True, True) + + def testRepeatIndicesMinScalar(self): + self._VariableRankTests(state_ops.scatter_min, True, True) + + def testRepeatIndicesMaxScalar(self): + self._VariableRankTests(state_ops.scatter_max, True, True) + def testBooleanScatterUpdate(self): if not test.is_gpu_available(): with self.test_session(use_gpu=False) as session: diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc index 994af69386..a07e305ffb 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.cc +++ b/tensorflow/python/lib/core/ndarray_tensor.cc @@ -267,7 +267,9 @@ gtl::InlinedVector<npy_intp, 4> GetPyArrayDimensionsForTensor( const int ndims = 
TF_NumDims(tensor); gtl::InlinedVector<npy_intp, 4> dims(ndims); if (TF_TensorType(tensor) == TF_RESOURCE) { - dims[0] = TF_TensorByteSize(tensor); + CHECK_EQ(ndims, 0) + << "Fetching of non-scalar resource tensors is not supported."; + dims.push_back(TF_TensorByteSize(tensor)); *nelems = dims[0]; } else { *nelems = 1; diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 02eafd42b3..22317a348c 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -166,7 +166,7 @@ bool IsSingleNone(PyObject* obj) { // Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`. tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, const Tensor** output_tensor) { - return EagerTensor_Handle(eager_tensor)->Tensor(output_tensor); + return EagerTensor_Handle(eager_tensor)->handle->Tensor(output_tensor); } // Calls the registered py function through the trampoline. diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py index 4071e50e81..0866fa8b0b 100644 --- a/tensorflow/python/ops/distributions/distribution.py +++ b/tensorflow/python/ops/distributions/distribution.py @@ -593,7 +593,7 @@ class Distribution(_BaseDistribution): Returns: batch_shape: `TensorShape`, possibly unknown. """ - return self._batch_shape() + return tensor_shape.as_shape(self._batch_shape()) def _event_shape_tensor(self): raise NotImplementedError("event_shape_tensor is not implemented") @@ -626,7 +626,7 @@ class Distribution(_BaseDistribution): Returns: event_shape: `TensorShape`, possibly unknown. """ - return self._event_shape() + return tensor_shape.as_shape(self._event_shape()) def is_scalar_event(self, name="is_scalar_event"): """Indicates that `event_shape == []`. 
@@ -1105,6 +1105,34 @@ class Distribution(_BaseDistribution): with self._name_scope(name): return self._kl_divergence(other) + def __str__(self): + return ("tf.distributions.{type_name}(" + "\"{self_name}\"" + "{maybe_batch_shape}" + "{maybe_event_shape}" + ", dtype={dtype})".format( + type_name=type(self).__name__, + self_name=self.name, + maybe_batch_shape=(", batch_shape={}".format(self.batch_shape) + if self.batch_shape.ndims is not None + else ""), + maybe_event_shape=(", event_shape={}".format(self.event_shape) + if self.event_shape.ndims is not None + else ""), + dtype=self.dtype.name)) + + def __repr__(self): + return ("<tf.distributions.{type_name} " + "'{self_name}'" + " batch_shape={batch_shape}" + " event_shape={event_shape}" + " dtype={dtype}>".format( + type_name=type(self).__name__, + self_name=self.name, + batch_shape=self.batch_shape, + event_shape=self.event_shape, + dtype=self.dtype.name)) + @contextlib.contextmanager def _name_scope(self, name=None, values=None): """Helper function to standardize op scope.""" diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 230b7ef937..e90ff0746a 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -80,6 +80,8 @@ from tensorflow.python.ops.state_ops import scatter_add from tensorflow.python.ops.state_ops import scatter_div from tensorflow.python.ops.state_ops import scatter_mul from tensorflow.python.ops.state_ops import scatter_sub +from tensorflow.python.ops.state_ops import scatter_min +from tensorflow.python.ops.state_ops import scatter_max from tensorflow.python.ops.state_ops import scatter_update from tensorflow.python.ops.state_ops import scatter_nd_add from tensorflow.python.ops.state_ops import scatter_nd_sub diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index c3ad5831b4..01fc3182bc 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ 
-63,6 +63,8 @@ @@scatter_nd_update @@scatter_sub @@scatter_update +@@scatter_min +@@scatter_max @@sparse_mask @@tables_initializer @@trainable_variables diff --git a/tensorflow/python/training/device_util.py b/tensorflow/python/training/device_util.py new file mode 100644 index 0000000000..f1137e80ab --- /dev/null +++ b/tensorflow/python/training/device_util.py @@ -0,0 +1,68 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Device-related support functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops + + +def canonicalize(d): + d = tf_device.DeviceSpec.from_string(d) + assert d.device_type is None or d.device_type == d.device_type.upper(), ( + "Device type '%s' must be all-caps." % (d.device_type,)) + # Fill in missing device fields using defaults. 
+ result = tf_device.DeviceSpec( + job="localhost", replica=0, task=0, device_type="CPU", device_index=0) + result.merge_from(d) + return result.to_string() + + +class _FakeNodeDef(object): + """A fake NodeDef for _FakeOperation.""" + + def __init__(self): + self.op = "" + self.name = "" + + +class _FakeOperation(object): + """A fake Operation object to pass to device functions.""" + + def __init__(self): + self.device = "" + self.type = "" + self.name = "" + self.node_def = _FakeNodeDef() + + def _set_device(self, device): + self.device = ops._device_string(device) # pylint: disable=protected-access + + +def current(): + """Return a string (not canonicalized) for the current device.""" + # TODO(josh11b): Work out how this function interacts with ops.colocate_with. + ctx = context.context() + if ctx.executing_eagerly(): + d = ctx.device_name + else: + op = _FakeOperation() + ops.get_default_graph()._apply_device_functions(op) # pylint: disable=protected-access + d = op.device + return d diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 03e3e0857f..ab5e6590e0 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -3157,12 +3157,18 @@ bool CudnnSupport::DoTransformTensor(Stream* stream, dnn::DataType output_type, float scale, DeviceMemoryBase* output_data) { mutex_lock lock{dnn_handle_mutex_}; + cudnnStatus_t status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_), + AsCUDAStreamValue(stream)); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status); + } + float beta = 0.0f; ScopedTensorDescriptor input_tensor_desc( parent_, input_desc, ToCudnnDataType(input_type, input_desc.layout())); ScopedTensorDescriptor output_tensor_desc( parent_, output_desc, ToCudnnDataType(output_type, output_desc.layout())); - cudnnStatus_t status = wrap::cudnnTransformTensor( + status = 
wrap::cudnnTransformTensor( parent_, ToHandle(dnn_handle_), &scale, input_tensor_desc.handle(), input_data.opaque(), &beta, output_tensor_desc.handle(), output_data->opaque()); diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 55b82dd765..937044aece 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -1689,6 +1689,14 @@ tf_module { argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { + name: "scatter_max" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "scatter_min" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { name: "scatter_mul" argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } diff --git a/tensorflow/tools/ci_build/builds/android.sh b/tensorflow/tools/ci_build/builds/android.sh index 564c5aa148..d81793efe0 100755 --- a/tensorflow/tools/ci_build/builds/android.sh +++ b/tensorflow/tools/ci_build/builds/android.sh @@ -29,7 +29,8 @@ echo "========== TensorFlow Demo Build Test ==========" # Enable sandboxing so that zip archives don't get incorrectly packaged # in assets/ dir (see https://github.com/bazelbuild/bazel/issues/2334) # TODO(gunan): remove extra flags once sandboxing is enabled for all builds. 
-bazel --bazelrc=/dev/null build -c opt --fat_apk_cpu=x86_64 \ +bazel --bazelrc=/dev/null build \ + --compilation_mode=opt --cxxopt=-std=c++11 --fat_apk_cpu=x86_64 \ --spawn_strategy=sandboxed --genrule_strategy=sandboxed \ //tensorflow/examples/android:tensorflow_demo diff --git a/tensorflow/tools/ci_build/builds/android_full.sh b/tensorflow/tools/ci_build/builds/android_full.sh index 9d449241e8..41dc66dd54 100755 --- a/tensorflow/tools/ci_build/builds/android_full.sh +++ b/tensorflow/tools/ci_build/builds/android_full.sh @@ -40,7 +40,8 @@ rm -rf ${AAR_LIB_TMP} for CPU in ${CPUS//,/ } do echo "========== Building native libs for Android ${CPU} ==========" - bazel build -c opt --config=monolithic --cpu=${CPU} \ + bazel build --config=monolithic --cpu=${CPU} \ + --compilation_mode=opt --cxxopt=-std=c++11 \ --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ //tensorflow/core:android_tensorflow_lib \ @@ -62,7 +63,8 @@ done # in assets/ dir (see https://github.com/bazelbuild/bazel/issues/2334) # TODO(gunan): remove extra flags once sandboxing is enabled for all builds. 
echo "========== Building TensorFlow Android Jar and Demo ==========" -bazel --bazelrc=/dev/null build -c opt --config=monolithic --fat_apk_cpu=${CPUS} \ +bazel --bazelrc=/dev/null build --config=monolithic --fat_apk_cpu=${CPUS} \ + --compilation_mode=opt --cxxopt=-std=c++11 \ --spawn_strategy=sandboxed --genrule_strategy=sandboxed \ //tensorflow/contrib/android:android_tensorflow_inference_java \ //tensorflow/contrib/android:android_tensorflow_inference_java.aar \ diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 34dd419f15..d22a465376 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -211,6 +211,7 @@ def _get_default_do_not_descend_map(): 'tf': ['cli', 'lib', 'wrappers'], 'tf.contrib': [ 'compiler', + 'distribute', 'grid_rnn', # Block contrib.keras to de-clutter the docs 'keras', diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py index 8a1c7db2ea..f8fb6ecb0c 100644 --- a/third_party/examples/eager/spinn/spinn.py +++ b/third_party/examples/eager/spinn/spinn.py @@ -51,6 +51,9 @@ import tensorflow.contrib.eager as tfe from tensorflow.contrib.eager.python.examples.spinn import data +layers = tf.keras.layers + + def _bundle(lstm_iter): """Concatenate a list of Tensors along 1st axis and split result into two. 
@@ -78,17 +81,16 @@ def _unbundle(state): return tf.split(tf.concat(state, 1), state[0].shape[0], axis=0) -class Reducer(tfe.Network): +# pylint: disable=not-callable +class Reducer(tf.keras.Model): """A module that applies reduce operation on left and right vectors.""" def __init__(self, size, tracker_size=None): super(Reducer, self).__init__() - self.left = self.track_layer(tf.layers.Dense(5 * size, activation=None)) - self.right = self.track_layer( - tf.layers.Dense(5 * size, activation=None, use_bias=False)) + self.left = layers.Dense(5 * size, activation=None) + self.right = layers.Dense(5 * size, activation=None, use_bias=False) if tracker_size is not None: - self.track = self.track_layer( - tf.layers.Dense(5 * size, activation=None, use_bias=False)) + self.track = layers.Dense(5 * size, activation=None, use_bias=False) else: self.track = None @@ -123,7 +125,7 @@ class Reducer(tfe.Network): return h, c -class Tracker(tfe.Network): +class Tracker(tf.keras.Model): """A module that tracks the history of the sentence with an LSTM.""" def __init__(self, tracker_size, predict): @@ -134,10 +136,10 @@ class Tracker(tfe.Network): predict: (`bool`) Whether prediction mode is enabled. """ super(Tracker, self).__init__() - self._rnn = self.track_layer(tf.nn.rnn_cell.LSTMCell(tracker_size)) + self._rnn = tf.nn.rnn_cell.LSTMCell(tracker_size) self._state_size = tracker_size if predict: - self._transition = self.track_layer(tf.layers.Dense(4)) + self._transition = layers.Dense(4) else: self._transition = None @@ -182,7 +184,7 @@ class Tracker(tfe.Network): return unbundled, None -class SPINN(tfe.Network): +class SPINN(tf.keras.Model): """Stack-augmented Parser-Interpreter Neural Network. See https://arxiv.org/abs/1603.06021 for more details. 
@@ -204,9 +206,9 @@ class SPINN(tfe.Network): """ super(SPINN, self).__init__() self.config = config - self.reducer = self.track_layer(Reducer(config.d_hidden, config.d_tracker)) + self.reducer = Reducer(config.d_hidden, config.d_tracker) if config.d_tracker is not None: - self.tracker = self.track_layer(Tracker(config.d_tracker, config.predict)) + self.tracker = Tracker(config.d_tracker, config.predict) else: self.tracker = None @@ -248,7 +250,7 @@ class SPINN(tfe.Network): trans = transitions[i] if self.tracker: # Invoke tracker to obtain the current tracker states for the sentences. - tracker_states, trans_hypothesis = self.tracker(buffers, stacks) + tracker_states, trans_hypothesis = self.tracker(buffers, stacks=stacks) if trans_hypothesis: trans = tf.argmax(trans_hypothesis, axis=-1) else: @@ -264,7 +266,8 @@ class SPINN(tfe.Network): trackings.append(tracking) if rights: - reducer_output = self.reducer(lefts, rights, trackings) + reducer_output = self.reducer( + lefts, right_in=rights, tracking=trackings) reduced = iter(reducer_output) for transition, stack in zip(trans, stacks): @@ -273,7 +276,27 @@ class SPINN(tfe.Network): return _bundle([stack.pop() for stack in stacks])[0] -class SNLIClassifier(tfe.Network): +class Perceptron(tf.keras.Model): + """One layer of the SNLIClassifier multi-layer perceptron.""" + + def __init__(self, dimension, dropout_rate, previous_layer): + """Configure the Perceptron.""" + super(Perceptron, self).__init__() + self.dense = tf.keras.layers.Dense(dimension, activation=tf.nn.elu) + self.batchnorm = layers.BatchNormalization() + self.dropout = layers.Dropout(rate=dropout_rate) + self.previous_layer = previous_layer + + def call(self, x, training): + """Run previous Perceptron layers, then this one.""" + x = self.previous_layer(x, training=training) + x = self.dense(x) + x = self.batchnorm(x, training=training) + x = self.dropout(x, training=training) + return x + + +class SNLIClassifier(tf.keras.Model): """SNLI Classifier 
Model. A model aimed at solving the SNLI (Standford Natural Language Inference) @@ -304,29 +327,24 @@ class SNLIClassifier(tfe.Network): self.config = config self.embed = tf.constant(embed) - self.projection = self.track_layer(tf.layers.Dense(config.d_proj)) - self.embed_bn = self.track_layer(tf.layers.BatchNormalization()) - self.embed_dropout = self.track_layer( - tf.layers.Dropout(rate=config.embed_dropout)) - self.encoder = self.track_layer(SPINN(config)) - - self.feature_bn = self.track_layer(tf.layers.BatchNormalization()) - self.feature_dropout = self.track_layer( - tf.layers.Dropout(rate=config.mlp_dropout)) - - self.mlp_dense = [] - self.mlp_bn = [] - self.mlp_dropout = [] - for _ in xrange(config.n_mlp_layers): - self.mlp_dense.append(self.track_layer(tf.layers.Dense(config.d_mlp))) - self.mlp_bn.append( - self.track_layer(tf.layers.BatchNormalization())) - self.mlp_dropout.append( - self.track_layer(tf.layers.Dropout(rate=config.mlp_dropout))) - self.mlp_output = self.track_layer(tf.layers.Dense( + self.projection = layers.Dense(config.d_proj) + self.embed_bn = layers.BatchNormalization() + self.embed_dropout = layers.Dropout(rate=config.embed_dropout) + self.encoder = SPINN(config) + + self.feature_bn = layers.BatchNormalization() + self.feature_dropout = layers.Dropout(rate=config.mlp_dropout) + + current_mlp = lambda result, training: result + for _ in range(config.n_mlp_layers): + current_mlp = Perceptron(dimension=config.d_mlp, + dropout_rate=config.mlp_dropout, + previous_layer=current_mlp) + self.mlp = current_mlp + self.mlp_output = layers.Dense( config.d_out, kernel_initializer=tf.random_uniform_initializer(minval=-5e-3, - maxval=5e-3))) + maxval=5e-3)) def call(self, premise, @@ -370,10 +388,10 @@ class SNLIClassifier(tfe.Network): # Run the batch-normalized and dropout-processed word vectors through the # SPINN encoder. 
- premise = self.encoder(premise_embed, premise_transition, - training=training) - hypothesis = self.encoder(hypothesis_embed, hypothesis_transition, - training=training) + premise = self.encoder( + premise_embed, transitions=premise_transition, training=training) + hypothesis = self.encoder( + hypothesis_embed, transitions=hypothesis_transition, training=training) # Combine encoder outputs for premises and hypotheses into logits. # Then apply batch normalization and dropuout on the logits. @@ -383,15 +401,12 @@ class SNLIClassifier(tfe.Network): self.feature_bn(logits, training=training), training=training) # Apply the multi-layer perceptron on the logits. - for dense, bn, dropout in zip( - self.mlp_dense, self.mlp_bn, self.mlp_dropout): - logits = tf.nn.elu(dense(logits)) - logits = dropout(bn(logits, training=training), training=training) + logits = self.mlp(logits, training=training) logits = self.mlp_output(logits) return logits -class SNLIClassifierTrainer(object): +class SNLIClassifierTrainer(tfe.Checkpointable): """A class that coordinates the training of an SNLIClassifier.""" def __init__(self, snli_classifier, lr): @@ -450,10 +465,11 @@ class SNLIClassifierTrainer(object): """ with tfe.GradientTape() as tape: tape.watch(self._model.variables) + # TODO(allenl): Allow passing Layer inputs as position arguments. 
logits = self._model(premise, - premise_transition, - hypothesis, - hypothesis_transition, + premise_transition=premise_transition, + hypothesis=hypothesis, + hypothesis_transition=hypothesis_transition, training=True) loss = self.loss(labels, logits) gradients = tape.gradient(loss, self._model.variables) @@ -517,7 +533,9 @@ def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu): snli_data, batch_size): if use_gpu: label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu() - logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False) + logits = trainer.model( + prem, premise_transition=prem_trans, hypothesis=hypo, + hypothesis_transition=hypo_trans, training=False) loss_val = trainer.loss(label, logits) batch_size = tf.shape(label)[0] mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size) @@ -609,29 +627,30 @@ def train_or_infer_spinn(embed, with tf.device(device), \ summary_writer.as_default(), \ tf.contrib.summary.always_record_summaries(): - with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(config.logdir)): - model = SNLIClassifier(config, embed) - global_step = tf.train.get_or_create_global_step() - trainer = SNLIClassifierTrainer(model, config.lr) + model = SNLIClassifier(config, embed) + global_step = tf.train.get_or_create_global_step() + trainer = SNLIClassifierTrainer(model, config.lr) + checkpoint = tfe.Checkpoint(trainer=trainer, global_step=global_step) + checkpoint.restore(tf.train.latest_checkpoint(config.logdir)) if inference_sentence_pair: # Inference mode. 
- with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(config.logdir)): - prem, prem_trans = inference_sentence_pair[0] - hypo, hypo_trans = inference_sentence_pair[1] - hypo_trans = inference_sentence_pair[1][1] - inference_logits = model( # pylint: disable=not-callable - tf.constant(prem), tf.constant(prem_trans), - tf.constant(hypo), tf.constant(hypo_trans), training=False) - inference_logits = inference_logits[0][1:] - max_index = tf.argmax(inference_logits) - print("\nInference logits:") - for i, (label, logit) in enumerate( - zip(data.POSSIBLE_LABELS, inference_logits)): - winner_tag = " (winner)" if max_index == i else "" - print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag)) + prem, prem_trans = inference_sentence_pair[0] + hypo, hypo_trans = inference_sentence_pair[1] + hypo_trans = inference_sentence_pair[1][1] + inference_logits = model( + tf.constant(prem), + premise_transition=tf.constant(prem_trans), + hypothesis=tf.constant(hypo), + hypothesis_transition=tf.constant(hypo_trans), + training=False) + inference_logits = inference_logits[0][1:] + max_index = tf.argmax(inference_logits) + print("\nInference logits:") + for i, (label, logit) in enumerate( + zip(data.POSSIBLE_LABELS, inference_logits)): + winner_tag = " (winner)" if max_index == i else "" + print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag)) return inference_logits train_len = train_data.num_batches(config.batch_size) @@ -650,20 +669,15 @@ def train_or_infer_spinn(embed, # remain on CPU. Same in _evaluate_on_dataset(). 
iterations += 1 - with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(config.logdir)): - batch_train_loss, batch_train_logits = trainer.train_batch( - label, prem, prem_trans, hypo, hypo_trans) + batch_train_loss, batch_train_logits = trainer.train_batch( + label, prem, prem_trans, hypo, hypo_trans) batch_size = tf.shape(label)[0] mean_loss(batch_train_loss.numpy(), weights=batch_size.gpu() if use_gpu else batch_size) accuracy(tf.argmax(batch_train_logits, axis=1), label) if iterations % config.save_every == 0: - all_variables = trainer.variables + [global_step] - saver = tfe.Saver(all_variables) - saver.save(os.path.join(config.logdir, "ckpt"), - global_step=global_step) + checkpoint.save(os.path.join(config.logdir, "ckpt")) if iterations % config.dev_every == 0: dev_loss, dev_frac_correct = _evaluate_on_dataset( |