diff options
252 files changed, 6480 insertions, 3238 deletions
@@ -27,7 +27,7 @@ guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's uphold this code.** **We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for -tracking requests and bugs. So please see +tracking requests and bugs. So please see [TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** diff --git a/RELEASE.md b/RELEASE.md index 0fad3b5d41..0720a8c639 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -96,6 +96,27 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, ç”°ä¼ * Starting from 1.6 release, our prebuilt binaries will use AVX instructions. This may break TF on older CPUs. +## Known Bugs +* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or + `CUDA_ILLEGAL_ADDRESS` failures. + + Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9 + and CUDA 9.1 sometimes does not properly compute the carry bit when + decomposing 64-bit address calculations with large offsets (e.g. `load [x + + large_constant]`) into 32-bit arithmetic in SASS. + + As a result, these versions of `ptxas` miscompile most XLA programs which use + more than 4GB of temp memory. This results in garbage results and/or + `CUDA_ERROR_ILLEGAL_ADDRESS` failures. + + A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a + fix for CUDA 9.0.x. Until the fix is available, the only workaround is to + [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x + or disable XLA:GPU. + + TensorFlow will print a warning if you use XLA:GPU with a known-bad version of + CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122. + ## Major Features And Improvements * [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager) preview version is now available. 
@@ -633,7 +654,7 @@ answered questions, and were part of inspiring discussions. * Fixed LIBXSMM integration. * Make decode_jpeg/decode_png/decode_gif handle all formats, since users frequently try to decode an image as the wrong type. * Improve implicit broadcasting lowering. -* Improving stability of GCS/Bigquery clients by a faster retrying of stale transmissions. +* Improving stability of GCS/BigQuery clients by a faster retrying of stale transmissions. * Remove OpKernelConstruction::op_def() as part of minimizing proto dependencies. * VectorLaplaceDiag distribution added. * Android demo no longer requires libtensorflow_demo.so to run (libtensorflow_inference.so still required) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 8a04953d4c..dc995d231d 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -543,7 +543,6 @@ filegroup( "//tensorflow/contrib/model_pruning:all_files", "//tensorflow/contrib/model_pruning/examples/cifar10:all_files", "//tensorflow/contrib/nccl:all_files", - "//tensorflow/contrib/ndlstm:all_files", "//tensorflow/contrib/nearest_neighbor:all_files", "//tensorflow/contrib/nn:all_files", "//tensorflow/contrib/opt:all_files", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index c46cb32aa4..314cbc657c 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -135,6 +135,10 @@ tf_cuda_library( testonly = 1, srcs = ["c_test_util.cc"], hdrs = ["c_test_util.h"], + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + ], deps = [ ":c_api", "//tensorflow/core:lib", diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 37439ff0be..3c1d5b5bf8 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -124,8 +124,9 @@ TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, return Const(tensor.get(), graph, s, name); } -void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, - const char* name, TF_Operation** op, bool check) { 
+void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name, TF_Operation** op, + bool check) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; TF_AddInputList(desc, add_inputs, 2); @@ -139,14 +140,14 @@ void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name) { TF_Operation* op; - AddHelper(l, r, graph, s, name, &op, true); + AddOpHelper(l, r, graph, s, name, &op, true); return op; } TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name) { TF_Operation* op; - AddHelper(l, r, graph, s, name, &op, false); + AddOpHelper(l, r, graph, s, name, &op, false); return op; } @@ -160,6 +161,26 @@ TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, return TF_FinishOperation(desc, s); } +void BinaryOpHelper(const char* op_name, TF_Operation* l, TF_Operation* r, + TF_Graph* graph, TF_Status* s, const char* name, + TF_Operation** op, bool check) { + TF_OperationDescription* desc = TF_NewOperation(graph, op_name, name); + TF_AddInput(desc, {l, 0}); + TF_AddInput(desc, {r, 0}); + *op = TF_FinishOperation(desc, s); + if (check) { + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); + } +} + +TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name) { + TF_Operation* op; + BinaryOpHelper("Min", l, r, graph, s, name, &op, true); + return op; +} + TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, const char* name) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 6acc2fec00..77520be010 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -69,6 +69,9 @@ TF_Operation* 
AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, const char* name = "add"); +TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "min"); + TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s, const char* name = "neg"); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 3505f70dc1..e55cb672e9 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -72,6 +72,7 @@ tf_cuda_cc_test( ], deps = [ ":c_api", + "//tensorflow/c:c_test_util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 3a6d2ce45b..9cd1accde9 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" @@ -47,19 +48,23 @@ using tensorflow::int64; using tensorflow::string; namespace { -bool IsCPU(tensorflow::Device* d) { +bool IsCPU(const tensorflow::Device* d) { return d == nullptr || d->tensorflow_gpu_device_info() == nullptr; } -bool IsXLA(tensorflow::Device* d) { +bool IsXLA(const tensorflow::Device* d) { if (d == nullptr) return false; const auto& device_type = d->attributes().device_type(); return device_type.find("XLA") != std::string::npos; } -string DeviceName(tensorflow::Device* d) { +string DeviceName(const tensorflow::Device* d) { return (d == nullptr) ? 
"cpu:0" : d->name(); } + +#ifdef TENSORFLOW_EAGER_USE_XLA +std::atomic_int_fast64_t func_id_generator(0); +#endif // TENSORFLOW_EAGER_USE_XLA } // namespace extern "C" { @@ -281,6 +286,14 @@ const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { return device->name().c_str(); } +void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { + op->use_xla = enable; +#ifndef TENSORFLOW_EAGER_USE_XLA + LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not " + "built with XLA support."; +#endif // TENSORFLOW_EAGER_USE_XLA +} + void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { // Questionable heuristic ... // @@ -523,6 +536,228 @@ tensorflow::Status ValidateInputTypeAndPlacement( } return tensorflow::Status::OK(); } + +#ifdef TENSORFLOW_EAGER_USE_XLA +// Synthesizes and returns a wrapper function over `op`, which must be a +// primitive op (e.g. matmul). +// +// The wrapper function conforms to the function signature expected by +// _XlaLaunchOp, with input params ordered by <constants, (variable) args and +// resources>. For example, if the op has input params <Const1, Arg2, Const3, +// Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5, +// Resource4> as the input params to the synthesized function. +// +// It populates `const_input_types`, `arg_input_types` and +// `op_input_to_func_input` based on the reordering results, that the caller can +// use them to build an _XlaLaunchOp. On error, it returns NULL, and sets +// `status` accordingly. +const tensorflow::FunctionDef* OpToFunction( + TFE_Op* op, std::vector<TF_DataType>* const_input_types, + std::vector<TF_DataType>* arg_input_types, + tensorflow::gtl::FlatMap<int, int>* op_input_to_func_input, + TF_Status* status) { + DCHECK(!op->is_function()); + + tensorflow::FunctionDef fdef; + + // Get the OpDef of the op we are trying to encapsulate. 
+ TFE_Context* ctx = op->ctx; + const tensorflow::OpRegistrationData* op_data; + { + tensorflow::tf_shared_lock l(ctx->functions_mu); + status->status = ctx->func_lib_def.LookUp(op->name, &op_data); + if (!status->status.ok()) { + return nullptr; + } + } + const tensorflow::OpDef& op_def = op_data->op_def; + + tensorflow::OpDef* signature = fdef.mutable_signature(); + + // Handle constant inputs. + const std::unordered_set<string> const_inputs( + *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name)); + + // First add place holders for the input args, so that we can refer to them by + // position in the next loop. Also tally up the resource inputs. + int num_resource_inputs = 0; + for (int i = 0; i < op_def.input_arg_size(); ++i) { + if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) { + ++num_resource_inputs; + } + signature->add_input_arg(); + } + + // Now we map the input params from `op_def` to `signature`, where the param + // ordering for `signature` is: <constants, args, resources>. 
+ int const_index = 0; + int arg_index = const_inputs.size(); + int resource_index = op_def.input_arg_size() - num_resource_inputs; + for (int i = 0; i < op_def.input_arg_size(); ++i) { + const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i); + tensorflow::OpDef::ArgDef* func_input_arg = nullptr; + if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) { + VLOG(1) << "For const input, mapping op input " << i << " to func input " + << const_index; + (*op_input_to_func_input)[i] = const_index; + func_input_arg = signature->mutable_input_arg(const_index++); + const_input_types->push_back( + static_cast<TF_DataType>(op->inputs[i].dtype())); + } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) { + VLOG(1) << "For resource input, mapping op input " << i + << " to func input " << resource_index; + (*op_input_to_func_input)[i] = resource_index; + func_input_arg = signature->mutable_input_arg(resource_index++); + } else { + VLOG(1) << "For arg input, mapping op input " << i << " to func input " + << arg_index; + (*op_input_to_func_input)[i] = arg_index; + func_input_arg = signature->mutable_input_arg(arg_index++); + arg_input_types->push_back( + static_cast<TF_DataType>(op->inputs[i].dtype())); + } + + func_input_arg->set_name(op_input_arg.name()); + func_input_arg->set_type(op->inputs[i].dtype()); + } + VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString(); + + // Resources args are at the end of the function input params, and we should + // have iterated over all of them. + DCHECK_EQ(signature->input_arg_size(), resource_index); + + // Make the synthesized function's name unique. + signature->set_name(tensorflow::strings::StrCat( + op_def.name(), func_id_generator.fetch_add(1))); + + // Add the node def and set its input names to match op_def's names. 
+ const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); + DCHECK_EQ(signature->input_arg_size(), ndef.input_size()); + *fdef.add_node_def() = ndef; + for (int i = 0; i < op_def.input_arg_size(); ++i) { + fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name()); + } + VLOG(1) << "Added NodeDef: " << fdef.DebugString(); + + // Fix the output names and set output types. + for (int i = 0; i < op_def.output_arg_size(); ++i) { + tensorflow::OpDef::ArgDef* arg = signature->add_output_arg(); + const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i); + const string& out_tensor_name = tensorflow::strings::StrCat( + ndef.name(), ":", op_def_arg.name(), ":", 0); + arg->set_name(op_def_arg.name()); + (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name; + const string& type_attr = op_def_arg.type_attr(); + if (!type_attr.empty()) { + auto i = ndef.attr().find(type_attr); + if (i == ndef.attr().end()) { + status->status = tensorflow::errors::InvalidArgument( + tensorflow::strings::StrCat("Could not find attr ", type_attr, + " in NodeDef ", ndef.DebugString())); + return nullptr; + } + arg->set_type(i->second.type()); + } + } + VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString(); + + tensorflow::mutex_lock l(ctx->functions_mu); + status->status = ctx->func_lib_def.AddFunctionDef(fdef); + if (!status->status.ok()) return nullptr; + const auto ret = ctx->func_lib_def.Find(signature->name()); + DCHECK(ret != nullptr); + return ret; +} + +// Builds an _XLALaunchOp as a wrapper over 'op', so that 'op' can be executed +// via XLA. 
+std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) { + VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name; + auto launch_op = + std::unique_ptr<TFE_Op>(TFE_NewOp(op->ctx, "_XlaLaunch", status)); + if (TF_GetCode(status) != TF_OK) return nullptr; + if (op->device) { + TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + + const tensorflow::FunctionDef* fdef; + { + tensorflow::tf_shared_lock l(op->ctx->functions_mu); + fdef = op->ctx->func_lib_def.Find(op->name); + } + std::vector<TF_DataType> const_input_types; + std::vector<TF_DataType> arg_input_types; + tensorflow::gtl::FlatMap<int, int> op_input_to_func_input; + if (fdef == nullptr) { + // See if this is a primitive op, and if so create a function for it, so + // that _XlaLaunchOp can access it. + fdef = OpToFunction(op, &const_input_types, &arg_input_types, + &op_input_to_func_input, status); + if (!status->status.ok()) return nullptr; + } else { + // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for + // functions, so we need to find another way to handle constant inputs. + for (int i = const_input_types.size(); + i < fdef->signature().input_arg_size(); ++i) { + VLOG(1) << "Adding Targs from input arg " << i; + const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i); + arg_input_types.push_back(static_cast<TF_DataType>(arg.type())); + } + } + DCHECK(fdef != nullptr); + + // Copy inputs and their devices. + // Since input param reordering may have occurred between `op` and `launch_op` + // via `op_input_to_func_input`, adjust the actual inputs accordingly. 
+ launch_op->inputs = op->inputs; + launch_op->input_devices = op->input_devices; + if (!op_input_to_func_input.empty()) { + DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size()); + if (!op->input_devices.empty()) { + DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size()); + } + for (int i = 0; i < op_input_to_func_input.size(); ++i) { + VLOG(1) << "mapping op input " << i << " to func input " + << op_input_to_func_input[i]; + + launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i]; + if (!op->input_devices.empty()) { + launch_op->input_devices[op_input_to_func_input[i]] = + op->input_devices[i]; + } + } + } + launch_op->attrs.NumInputs(op->inputs.size()); + + TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(), + const_input_types.size()); + + // Set Targs and Nresources attrs. + TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(), + arg_input_types.size()); + const int num_resource_inputs = fdef->signature().input_arg_size() - + const_input_types.size() - + arg_input_types.size(); + TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs); + + // Set Tresults attr. + std::vector<TF_DataType> tresults; + for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) { + tresults.push_back(static_cast<TF_DataType>(arg.type())); + } + TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(), + tresults.size()); + + // Set function attr. 
+ tensorflow::AttrValue attr_value; + tensorflow::NameAttrList* func = attr_value.mutable_func(); + func->set_name(fdef->signature().name()); + launch_op->attrs.Set("function", attr_value); + + return launch_op; +} +#endif // TENSORFLOW_EAGER_USE_XLA } // namespace void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, @@ -531,6 +766,18 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU tensorflow::Device* device = (op->device == nullptr) ? ctx->devices()[0] : op->device; + +#ifdef TENSORFLOW_EAGER_USE_XLA + std::unique_ptr<TFE_Op> xla_launch_op; + if (op->use_xla && op->name != "_XlaLaunch") { + xla_launch_op = BuildXlaLaunch(op, status); + if (!status->status.ok()) { + return; + } + op = xla_launch_op.get(); + } +#endif // TENSORFLOW_EAGER_USE_XLA + std::vector<tensorflow::Tensor> outputs(1); const tensorflow::MemoryTypeVector* output_memory_types = nullptr; tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name()); diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 6a2aff1591..9506cf7390 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -158,6 +158,13 @@ TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status); +// When 'enable' is set to 1, and if TensorFlow library is built with XLA +// support, a subsequent TFE_Execute() call on `op` will run the op via XLA. +// +// If the library is not built with XLA support, this call would be a no-op. 
+TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op, + unsigned char enable); + TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index f2abffb7bc..7b9f1db02e 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -121,6 +121,7 @@ struct TFE_Op { std::vector<tensorflow::Tensor> inputs; std::vector<tensorflow::Device*> input_devices; tensorflow::Device* device; + bool use_xla = false; }; #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index b0409af87c..4a3ecbc0ab 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -60,6 +60,38 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { return op; } +TFE_TensorHandle* TestAxisTensorHandle() { + int64_t dims[] = {1}; + int data[] = {1}; + TF_Tensor* t = TF_AllocateTensor( + TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "Min", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, input, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, axis, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrBool(op, "keep_dims", 1); + TFE_OpSetAttrType(op, "Tidx", TF_INT32); 
+ TF_DeleteStatus(status); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); + + return op; +} + // If there is a GPU device, returns true and sets 'gpu_device_name' // accordingly. bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) { @@ -410,7 +442,7 @@ TEST(CAPI, SetAndGetOpDevices) { TF_DeleteStatus(status); } -TEST(CAPI, Execute) { +TEST(CAPI, Execute_MatMul_CPU) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); @@ -443,6 +475,117 @@ TEST(CAPI, Execute) { TF_DeleteStatus(status); } +TEST(CAPI, Execute_Min_CPU) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* input = TestMatrixTensorHandle(); + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* minOp = MinOp(ctx, input, axis); + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. 
+ TFE_Execute(minOp, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(minOp); + TFE_DeleteTensorHandle(input); + TFE_DeleteTensorHandle(axis); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float output[2] = {0}; + EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); + memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(1, output[0]); + EXPECT_EQ(3, output[1]); + TF_DeleteStatus(status); +} + +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST(CAPI, Execute_MatMul_XLA_CPU) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + + TFE_OpSetXLACompilation(matmul, true); + + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + // Running a primitive TF operator via XLA is not yet supported. 
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + EXPECT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TF_DeleteStatus(status); +} + +TEST(CAPI, Execute_Min_XLA_CPU) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* input = TestMatrixTensorHandle(); + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* minOp = MinOp(ctx, input, axis); + + TFE_OpSetXLACompilation(minOp, true); + + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. 
+ TFE_Execute(minOp, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(minOp); + TFE_DeleteTensorHandle(input); + TFE_DeleteTensorHandle(axis); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float output[2] = {0}; + EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); + memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(1, output[0]); + EXPECT_EQ(3, output[1]); + TF_DeleteStatus(status); +} +#endif // TENSORFLOW_EAGER_USE_XLA + TEST(CAPI, ExecuteWithTracing) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -484,7 +627,7 @@ TEST(CAPI, ExecuteWithTracing) { TF_DeleteStatus(status); } -TEST(CAPI, Function) { +TEST(CAPI, Function_ident_CPU) { // First create a simple identity function. TF_Graph* function_graph = TF_NewGraph(); TF_OperationDescription* arg_descr = @@ -545,6 +688,72 @@ TEST(CAPI, Function) { TF_DeleteStatus(status); } +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST(CAPI, Function_ident_XLA_CPU) { + // First create a simple identity function. 
+ TF_Graph* function_graph = TF_NewGraph(); + TF_OperationDescription* arg_descr = + TF_NewOperation(function_graph, "Placeholder", "arg"); + TF_SetAttrType(arg_descr, "dtype", TF_INT32); + TF_Status* status = TF_NewStatus(); + TF_Operation* arg = TF_FinishOperation(arg_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_OperationDescription* id_descr = + TF_NewOperation(function_graph, "Identity", "id"); + TF_SetAttrType(id_descr, "T", TF_INT32); + TF_AddInput(id_descr, {arg, 0}); + TF_Operation* id = TF_FinishOperation(id_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_Output input{arg, 0}; + TF_Output output{id, 0}; + TF_Function* fn = + TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, + &output, nullptr, nullptr, "test", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteGraph(function_graph); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + TFE_ContextAddFunction(ctx, fn, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteFunction(fn); + + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); + + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + + // Now run it via XLA. 
+ TFE_OpSetXLACompilation(op, true); + + std::vector<TFE_TensorHandle*> result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); + + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + TFE_DeleteContext(ctx, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteStatus(status); +} +#endif // TENSORFLOW_EAGER_USE_XLA + string MatMulFunction() { tensorflow::FunctionDef def; CHECK(tensorflow::protobuf::TextFormat::ParseFromString( diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 0540260efd..bc46918df9 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -132,7 +132,10 @@ tf_library( config = "test_graph_tfadd.config.pbtxt", cpp_class = "AddComp", graph = "test_graph_tfadd.pbtxt", - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) # A test of tf_library that includes a graph with an unknown op, but where @@ -143,7 +146,10 @@ tf_library( config = "test_graph_tfunknownop.config.pbtxt", cpp_class = "UnknownOpAddComp", graph = "test_graph_tfunknownop.pbtxt", - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) # A test of tf_library that includes a graph with an unknown op, but where @@ -155,7 +161,10 @@ tf_library( config = "test_graph_tfunknownop2.config.pbtxt", cpp_class = "UnknownOpAddComp", graph = "test_graph_tfunknownop.pbtxt", - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) # A test of tf_library that includes a graph with an unknown op, but where @@ -166,7 +175,10 @@ tf_library( config = "test_graph_tfunknownop3.config.pbtxt", cpp_class 
= "UnknownOpAddComp", graph = "test_graph_tfunknownop.pbtxt", - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) # Utility library for benchmark binaries, used by the *_benchmark rules that are diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 7dfd49cc3b..43d8ae4108 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -74,7 +74,10 @@ tf_library( # compile but the others in this directory succeed, you may need to # expand the "required by all tf_library targets" list in tfcompile.bzl. include_standard_runtime_deps = False, - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) tf_library( @@ -84,7 +87,10 @@ tf_library( cpp_class = "AddWithCkptComp", freeze_checkpoint = "test_graph_tfadd_with_ckpt.ckpt", graph = "test_graph_tfadd_with_ckpt.pb", - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) tf_library( @@ -95,7 +101,10 @@ tf_library( freeze_checkpoint = "test_graph_tfadd_with_ckpt_saver.ckpt", freeze_saver = "test_graph_tfadd_with_ckpt_saver.saver", graph = "test_graph_tfadd_with_ckpt_saver.pb", - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) tf_library( @@ -104,7 +113,10 @@ tf_library( config = "test_graph_tffunction.config.pbtxt", cpp_class = "FunctionComp", graph = "test_graph_tffunction.pb", - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) tf_library( @@ -113,7 +125,10 @@ tf_library( config = "test_graph_tfgather.config.pbtxt", cpp_class = "GatherComp", graph = "test_graph_tfgather.pb", - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) tf_library( @@ -122,7 +137,10 @@ tf_library( config = "test_graph_tfmatmul.config.pbtxt", cpp_class = "foo::bar::MatMulComp", graph = "test_graph_tfmatmul.pb", - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) tf_library( @@ -131,7 +149,10 @@ tf_library( config = "test_graph_tfmatmulandadd.config.pbtxt", cpp_class = "MatMulAndAddComp", graph = 
"test_graph_tfmatmulandadd.pb", - tags = ["manual"], + tags = [ + "manual", + "notap", + ], tfcompile_flags = "--gen_name_to_index --gen_program_shape", ) @@ -141,13 +162,19 @@ tf_library( config = "test_graph_tfsplits.config.pbtxt", cpp_class = "SplitsComp", graph = "test_graph_tfsplits.pb", - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) tf_cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], - tags = ["manual"], + tags = [ + "manual", + "notap", + ], deps = [ ":test_graph_tfadd", ":test_graph_tfadd_with_ckpt", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 454f0aeae9..1a8858ccce 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -80,7 +81,7 @@ TEST(XlaCompilationTest, Chains) { ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); ops::UnaryOp("Relu", e, builder.opts().WithName("F")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -105,7 +106,7 @@ TEST(XlaCompilationTest, UncompilableCycles) { Node* b = ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -125,7 +126,7 @@ TEST(XlaCompilationTest, 
CompilableCycles) { .WithAttr("value", Tensor())); Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -148,7 +149,7 @@ TEST(XlaCompilationTest, UnsupportedTypes) { .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape()))); Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -177,7 +178,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) { concat_builder.Input(dim).Input({a, a}).Attr("N", 2); builder.opts().FinalizeBuilder(&concat_builder); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -212,7 +213,7 @@ TEST(XlaCompilationTest, FunctionCalls) { Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D")); ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def)); @@ -244,7 +245,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) { Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C")); Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D")); ops::UnaryOp("Shape", d, builder.opts().WithName("E")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); @@ -330,7 +331,7 @@ TEST(XlaCompilationTest, SymbolicGradients) { 
d_builder.Input({c, c}); builder.opts().FinalizeBuilder(&d_builder); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -382,7 +383,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { ops::BinaryOp( "MatMul", a, b, builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC")); - TF_CHECK_OK(builder.ToGraph(graph.get())); + TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -413,7 +414,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) { ops::BinaryOp( "Add", b, c, builder.opts().WithName("D").WithAttr(kXlaScopeAttr, "Scope2")); - TF_CHECK_OK(builder.ToGraph(graph.get())); + TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -443,7 +444,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { "Relu", a, builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - TF_CHECK_OK(builder.ToGraph(graph.get())); + TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -484,7 +485,7 @@ TEST(XlaCompilationTest, Resources) { Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C")); Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D")); ops::UnaryOp("Relu", d, builder.opts().WithName("E")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); @@ -541,7 +542,7 @@ TEST(XlaCompilationTest, Retval) { .WithAttr("T", DT_FLOAT) .WithAttr("index", 0)); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); diff --git 
a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index e43ea50af4..0f08eb3a32 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -61,13 +61,12 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -83,13 +82,12 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Constant())); } @@ -110,13 +108,12 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2)); - auto module = CreateNewModule(); - auto computation = 
module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); } @@ -133,13 +130,12 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -156,13 +152,12 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -178,13 +173,12 @@ 
TEST_F(AlgebraicSimplifierTest, SubZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -200,13 +194,12 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kSubtract, param0, constant)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); } @@ -226,15 +219,14 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Divide(param0, param1), param2)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - 
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Multiply(param1, param2))); @@ -255,15 +247,14 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Divide(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Multiply(param0, param2), param1)); @@ -289,8 +280,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), @@ -298,7 +288,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -320,15 +310,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); 
EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Exp(op::Negate(param1)))); @@ -349,15 +338,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Power(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); @@ -380,15 +368,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Power(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); ASSERT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); @@ -411,12 +398,11 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { 
builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Divide(op::Constant(), constant))); @@ -438,11 +424,10 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, inner_power, exp2)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Power(base, op::Multiply(exp1, exp2))); } @@ -451,24 +436,23 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { // numbers. 
TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { Shape r0c64 = ShapeUtil::MakeShape(C64, {}); - Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); + Shape r1c64 = ShapeUtil::MakeShape(C64, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction::CreateParameter(0, r1c64, "param0")); HloInstruction* exp1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r0c64, "param1")); HloInstruction* exp2 = builder.AddInstruction( HloInstruction::CreateParameter(2, r0c64, "param2")); HloInstruction* inner_power = builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1)); - builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, + HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1)); + builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, inner_power, exp2)); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); } // Test that A/1 is simplified to A for a scalar. 
@@ -482,13 +466,12 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -504,13 +487,12 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -529,13 +511,12 @@ TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { HloInstruction* cplx = builder.AddInstruction( HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, cplx); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); 
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -554,13 +535,12 @@ TEST_F(AlgebraicSimplifierTest, RealOfComplex) { HloInstruction* real = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, real); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -579,13 +559,12 @@ TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { HloInstruction* imag = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, imag); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param1); } @@ -607,13 +586,12 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, add); AlgebraicSimplifier 
simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param1, param2)); } @@ -633,15 +611,14 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Exp(param0), op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Subtract(param0, param1))); @@ -662,15 +639,14 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Exp(param0), op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Add(param0, param1))); @@ -689,15 +665,14 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = 
module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(op::Exp(param0), param1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Multiply(param0, param1))); @@ -716,15 +691,14 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Power(param0, param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Log(param0), param1)); @@ -741,14 +715,13 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } @@ -770,15 +743,14 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); 
+ auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); } @@ -795,14 +767,13 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); @@ -820,14 +791,13 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast()); @@ -849,14 +819,13 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, 
one)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } @@ -872,14 +841,13 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); } @@ -895,14 +863,13 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); @@ -941,16 +908,15 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { 
dim->set_base_dilation(1); dim->set_window_reversal(false); // Create add computation. - std::unique_ptr<HloModule> module = CreateNewModule(); builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums)); - module->AddEntryComputation(builder.Build()); + module().AddEntryComputation(builder.Build()); HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_THAT(module->entry_computation()->root_instruction(), + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Convolution(lhs, rhs)); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } @@ -969,7 +935,6 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { dim->set_base_dilation(1); } // Create add computation. 
- std::unique_ptr<HloModule> module = CreateNewModule(); HloComputation* add_computation = nullptr; { HloComputation::Builder builder(TestName() + ".add"); @@ -980,20 +945,20 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module->AddEmbeddedComputation(builder.Build()); + add_computation = module().AddEmbeddedComputation(builder.Build()); } builder.AddInstruction(HloInstruction::CreateReduceWindow( ShapeUtil::MakeShape(F32, {5, 2}), param, builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))), window, add_computation)); - module->AddEntryComputation(builder.Build()); + module().AddEntryComputation(builder.Build()); HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_THAT(module->entry_computation()->root_instruction(), + EXPECT_THAT(module().entry_computation()->root_instruction(), op::ReduceWindow(param, op::Constant())); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } @@ -1014,14 +979,13 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), padding)); - std::unique_ptr<HloModule> module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_THAT(module->entry_computation()->root_instruction(), + module().AddEntryComputation(builder.Build()); + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Pad(param, op::Constant())); HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - 
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } @@ -1039,17 +1003,16 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { ShapeUtil::MakeShape(F32, {3, 2}), broadcast)); auto computation = builder.Build(); - auto module = CreateNewModule(); - module->AddEntryComputation(std::move(computation)); + module().AddEntryComputation(std::move(computation)); - EXPECT_THAT(module->entry_computation()->root_instruction(), + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reshape(op::Broadcast(op::Reshape(op)))); HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), op); + EXPECT_THAT(module().entry_computation()->root_instruction(), op); } // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE. 
@@ -1060,14 +1023,13 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), input); } @@ -1081,14 +1043,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } @@ -1102,14 +1063,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { builder.AddInstruction( HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); 
} @@ -1132,8 +1092,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), @@ -1141,7 +1100,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0, param0, param1)); @@ -1163,15 +1122,14 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, empty_slice}, 0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(empty_literal, empty_slice)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), empty_literal); } @@ -1188,14 +1146,13 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { HloInstruction* broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(r1f32, param1, {})); builder.AddInstruction(HloInstruction::CreateConcatenate( - param0->shape(), {broadcast, param0}, 0)); + ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); 
+ auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1)); } @@ -1209,8 +1166,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); // Set to different layouts. *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -1220,7 +1176,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); // Copy has not been removed. EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); @@ -1236,8 +1192,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); // Set to same layouts. *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -1247,7 +1202,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); // Copy has been removed. 
EXPECT_THAT(computation->root_instruction(), param0); @@ -1268,14 +1223,13 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { *reshape->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); // Reshape is not replaced with a bitcast. EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); @@ -1314,8 +1268,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { builder.AddInstruction(HloInstruction::CreateTuple( {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape})); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Tuple(transformable_reshape, dimensions_wrong_reshape, @@ -1323,7 +1276,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - simplifier.Run(module.get()).ValueOrDie(); + simplifier.Run(&module()).ValueOrDie(); // Verify that only the first reshape is replaced. 
EXPECT_THAT( @@ -1344,8 +1297,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { builder.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), HloOpcode::kMaximum, movable_reshape, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Reshape(param), zero)); @@ -1353,7 +1305,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - simplifier.Run(module.get()).ValueOrDie(); + simplifier.Run(&module()).ValueOrDie(); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Maximum(param, zero))); } @@ -1371,8 +1323,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { HloInstruction::CreateConstant(Literal::CreateR1<float>({1., 2., 3.}))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Reshape(param), zero)); @@ -1380,7 +1331,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - simplifier.Run(module.get()).ValueOrDie(); + simplifier.Run(&module()).ValueOrDie(); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Reshape(param), zero)); @@ -1405,9 +1356,8 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - auto module = CreateNewModule(); - 
module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + module().AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); } // Regression test for a bug where if we failed to sink a reshape, we'd set the @@ -1424,14 +1374,14 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR2<float>({{0, 0}, {0, 0}}))))); - builder.AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0})); + builder.AddInstruction( + HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, + /*broadcast_dimensions=*/{0, 1})); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + module().AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { @@ -1448,14 +1398,13 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); // Verify that the reshape is replaced. 
EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); @@ -1475,14 +1424,13 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({3, 1, 2, 0}); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); // Verify that the reshape is replaced. EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); @@ -1501,15 +1449,14 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Reshape(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } @@ -1529,14 +1476,13 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}), HloOpcode::kCopy, copy1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - 
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); } @@ -1554,14 +1500,13 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2})); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); EXPECT_EQ(std::vector<int64>({2, 1, 0}), @@ -1576,17 +1521,16 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 5, 1}), param0)); builder.AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3})); + ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2})); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(op::Reshape(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } @@ -1601,15 +1545,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { builder.AddInstruction(HloInstruction::CreateReshape( 
ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } @@ -1623,15 +1566,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); @@ -1646,15 +1588,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + HloComputation* computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 
EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); EXPECT_THAT(computation->root_instruction()->dimensions(), @@ -1670,15 +1611,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + HloComputation* computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); const std::vector<int64> broadcast_dims = @@ -1696,15 +1636,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 8}), broadcast)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + HloComputation* computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); @@ -2410,12 +2349,11 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { call_builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); - auto module = CreateNewModule(); - module->AddEmbeddedComputation(std::move(dot_computation)); - 
module->AddEntryComputation(call_builder.Build()); + module().AddEmbeddedComputation(std::move(dot_computation)); + module().AddEntryComputation(call_builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); } // Test that a constant with tuple shape becomes a tuple of constants. @@ -2428,12 +2366,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { Literal::CreateR1<float>(constant_vector).get()}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } @@ -2453,11 +2390,10 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0}))), /*slice_sizes=*/{10, 100, 1000})); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Parameter()); } @@ -2487,11 +2423,10 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0}))))); - auto module = CreateNewModule(); - auto computation = 
module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::DynamicSlice(op::Parameter(), op::Parameter())); } @@ -2554,15 +2489,16 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { PaddingConfig padding = window_util::MakeSymmetricPadding( decorate_spatials(param.symmetric_pad_spatials, 0, 0)); + TF_ASSERT_OK_AND_ASSIGN( + const Shape pad_shape, + ShapeInference::InferPadShape(input->shape(), + ShapeUtil::MakeShape(F32, {}), padding)); HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( - ShapeUtil::MakeShape( - F32, decorate_spatials(param.reduce_window_spatials, 128, 2048)), - input, + pad_shape, input, builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), padding)); - std::unique_ptr<HloModule> module = CreateNewModule(); HloComputation* add_computation = nullptr; { HloComputation::Builder builder(TestName() + ".add"); @@ -2573,24 +2509,24 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module->AddEmbeddedComputation(builder.Build()); + add_computation = module().AddEmbeddedComputation(builder.Build()); } - TF_ASSERT_OK_AND_ASSIGN( - const Shape output_shape, - ShapeInference::InferPadShape(input_shape, ShapeUtil::MakeShape(F32, {}), - padding)); Window window = window_util::MakeWindow( decorate_spatials(param.reduce_window_spatials, 1, 1)); auto zero = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, + 
ShapeInference::InferReduceWindowShape( + pad->shape(), zero->shape(), window, + add_computation->ComputeProgramShape())); builder.AddInstruction(HloInstruction::CreateReduceWindow( output_shape, pad, zero, window, add_computation)); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -2667,11 +2603,10 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { dot_dnums.add_rhs_contracting_dimensions(0); builder.AddInstruction( HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(&module())); const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; const bool computation_should_be_modified = dot_should_be_transformed || (transpose_lhs && transpose_rhs); @@ -2699,7 +2634,7 @@ struct DotOfConcatTestSpec { }; class DotOfConcatSimplificationTest - : public HloTestBase, + : public HloVerifiedTestBase, public ::testing::WithParamInterface<DotOfConcatTestSpec> {}; // Test that we transform @@ -2745,11 +2680,10 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { builder.AddInstruction( HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = 
module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -2790,17 +2724,17 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { HloInstruction* lhs2 = builder.AddInstruction( HloInstruction::CreateParameter(2, lhs2_shape, "lhs2")); HloInstruction* lhs3 = builder.AddInstruction( - HloInstruction::CreateParameter(3, lhs2_shape, "lhs3")); + HloInstruction::CreateParameter(3, lhs3_shape, "lhs3")); Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConcatenate( lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1)); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.m}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n}); auto* rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( - /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.m))); + /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n))); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -2810,11 +2744,10 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { builder.AddInstruction( HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); diff 
--git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 2f02591631..1a91dd8ff7 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -497,6 +497,7 @@ cc_library( "llvm_ir_runtime.h", ], deps = [ + ":vector_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@llvm//:core", @@ -852,6 +853,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "@llvm//:core", + "@llvm//:support", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 04b4a8c5c8..2723661712 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -200,25 +200,16 @@ std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl( std::vector<llvm::VecDesc> vector_functions; const llvm::VecDesc four_wide_vector_functions_neon[] = { - {"expf", runtime::kExpV4F32NEONSymbolName, 4}, - {"llvm.exp.f32", runtime::kExpV4F32NEONSymbolName, 4}, - {"logf", runtime::kLogV4F32NEONSymbolName, 4}, {"llvm.log.f32", runtime::kLogV4F32NEONSymbolName, 4}, }; const llvm::VecDesc four_wide_vector_functions_sse[] = { - {"expf", runtime::kExpV4F32SSESymbolName, 4}, - {"llvm.exp.f32", runtime::kExpV4F32SSESymbolName, 4}, - {"logf", runtime::kLogV4F32SSESymbolName, 4}, {"llvm.log.f32", runtime::kLogV4F32SSESymbolName, 4}, }; const llvm::VecDesc eight_wide_vector_functions_avx[] = { - {"expf", runtime::kExpV8F32AVXSymbolName, 8}, - {"llvm.exp.f32", runtime::kExpV8F32AVXSymbolName, 8}, - {"logf", runtime::kLogV8F32AVXSymbolName, 8}, {"llvm.log.f32", runtime::kLogV8F32AVXSymbolName, 8}, }; @@ -231,6 +222,12 @@ std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl( {"tanhf", runtime::kTanhV8F32SymbolName, 8}, {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName, 
8}, + + {"expf", runtime::kExpV4F32SymbolName, 4}, + {"llvm.exp.f32", runtime::kExpV4F32SymbolName, 4}, + + {"expf", runtime::kExpV8F32SymbolName, 8}, + {"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8}, }; llvm::SmallVector<llvm::StringRef, 32> features; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc index b1c1142e8d..62bb87f2b0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc @@ -20,11 +20,6 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" #ifdef TF_XLA_HAS_AVX -xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX( - xla::cpu::runtime::V8F32AVX x) { - return Eigen::internal::pexp(x); -} - xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX( xla::cpu::runtime::V8F32AVX x) { return Eigen::internal::plog(x); @@ -35,7 +30,6 @@ namespace xla { namespace cpu { namespace runtime { -const char *const kExpV8F32AVXSymbolName = "__xla_cpu_runtime_ExpV8F32AVX"; const char *const kLogV8F32AVXSymbolName = "__xla_cpu_runtime_LogV8F32AVX"; } // namespace runtime diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h index e5c782f93f..f473c689f2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h @@ -33,7 +33,6 @@ namespace xla { namespace cpu { namespace runtime { -extern const char *const kExpV8F32AVXSymbolName; extern const char *const kLogV8F32AVXSymbolName; #ifdef TF_XLA_HAS_AVX diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc index d8ecf231cc..1d5b5c2c1e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc @@ -21,12 +21,6 @@ limitations under the License. 
#ifdef TF_XLA_HAS_SSE4_1 -xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE( - xla::cpu::runtime::V4F32SSE x) { - Eigen::internal::Packet4f p = x; - return Eigen::internal::pexp(p); -} - xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE( xla::cpu::runtime::V4F32SSE x) { Eigen::internal::Packet4f p = x; @@ -39,7 +33,6 @@ namespace xla { namespace cpu { namespace runtime { -const char *const kExpV4F32SSESymbolName = "__xla_cpu_runtime_ExpV4F32SSE"; const char *const kLogV4F32SSESymbolName = "__xla_cpu_runtime_LogV4F32SSE"; } // namespace runtime diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h index aeb1eda23f..3b3d18112a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h @@ -35,7 +35,6 @@ namespace xla { namespace cpu { namespace runtime { -extern const char *const kExpV4F32SSESymbolName; extern const char *const kLogV4F32SSESymbolName; #ifdef TF_XLA_HAS_SSE4_1 @@ -52,9 +51,6 @@ extern "C" { // The following functions are vectorized versions of a selection of libm // library functions. // References to these functions are created by the LLVM vectorizer. -xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE( - xla::cpu::runtime::V4F32SSE x); - xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE( xla::cpu::runtime::V4F32SSE x); #endif diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index 0336fa6131..38fcd278e9 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Verifier.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -28,6 +29,8 @@ namespace runtime { const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32"; const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32"; +const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32"; +const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32"; namespace { llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, @@ -42,27 +45,22 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, } llvm::LLVMContext* context = &module->getContext(); - llvm::Type* float_type = llvm::Type::getFloatTy(*context); - llvm::VectorType* vector_type = - llvm::VectorType::get(float_type, vector_width); llvm::BasicBlock* vector_tanh_body = llvm::BasicBlock::Create(*context, "body", vector_tanh_function); llvm::IRBuilder<> ir_builder(vector_tanh_body); - llvm::FastMathFlags fast_math_flags; fast_math_flags.setFast(); ir_builder.setFastMathFlags(fast_math_flags); + VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "tanh_f32"); + llvm::Value* input = &*vector_tanh_function->arg_begin(); - CHECK_EQ(input->getType(), vector_type); + CHECK_EQ(input->getType(), vsl.vector_type()); // This implements the same rational interpolant as implemented in Eigen3. 
- llvm::Value* input_clamped = llvm_ir::EmitFloatMin( - llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(vector_type, -9.0), - &ir_builder), - llvm::ConstantFP::get(vector_type, 9.0), &ir_builder); + llvm::Value* input_clamped = vsl.Clamp(input, /*low=*/-9.0, /*high=*/9.0); std::array<float, 7> numerator_coeffs{ -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, @@ -73,31 +71,105 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, 4.89352518554385e-03f}; - llvm::Value* input_squared = - ir_builder.CreateFMul(input_clamped, input_clamped); - llvm::Value* numerator = - llvm::ConstantFP::get(vector_type, numerator_coeffs[0]); + llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped); + llvm::Value* numerator = vsl.SplatFloat(numerator_coeffs[0]); for (int i = 1; i < numerator_coeffs.size(); i++) { - numerator = ir_builder.CreateFAdd( - ir_builder.CreateFMul(input_squared, numerator), - llvm::ConstantFP::get(vector_type, numerator_coeffs[i])); + numerator = vsl.MulAdd(input_squared, numerator, numerator_coeffs[i]); } - numerator = ir_builder.CreateFMul(input_clamped, numerator); - llvm::Value* denominator = - llvm::ConstantFP::get(vector_type, denominator_coeffs[0]); + numerator = vsl.Mul(input_clamped, numerator); + + llvm::Value* denominator = vsl.SplatFloat(denominator_coeffs[0]); for (int i = 1; i < denominator_coeffs.size(); i++) { - denominator = ir_builder.CreateFAdd( - ir_builder.CreateFMul(input_squared, denominator), - llvm::ConstantFP::get(vector_type, denominator_coeffs[i])); + denominator = vsl.MulAdd(input_squared, denominator, denominator_coeffs[i]); } - llvm::Value* result = ir_builder.CreateFDiv(numerator, denominator); + llvm::Value* result = vsl.Div(numerator, denominator); ir_builder.CreateRet(result); DCHECK(!llvm::verifyFunction(*vector_tanh_function)); return vector_tanh_function; } + +llvm::Function* 
EmitVectorF32ExpIfNeeded(llvm::Module* module, + llvm::StringRef function_name, + int vector_width, + bool enable_fast_math) { + llvm::Function* vector_exp_function = module->getFunction(function_name); + if (vector_exp_function == nullptr) { + // If the function declaration is not present in the module, there can't be + // any calls to resolve. Don't emit the function in this case. + return nullptr; + } + + llvm::LLVMContext* context = &module->getContext(); + + llvm::BasicBlock* vector_exp_body = + llvm::BasicBlock::Create(*context, "body", vector_exp_function); + + llvm::IRBuilder<> ir_builder(vector_exp_body); + llvm::FastMathFlags fast_math_flags; + fast_math_flags.setFast(); + ir_builder.setFastMathFlags(fast_math_flags); + + VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "exp_f32"); + + // This implements the same polynomial approximation as implemented in Eigen3. + + const double exp_hi = 88.3762626647950; + const double exp_lo = -88.3762626647949; + + const double cephes_LOG2EF = 1.44269504088896341; + const double cephes_exp_C1 = 0.693359375; + const double cephes_exp_C2 = -2.12194440e-4; + + const double cephes_exp_p0 = 1.9875691500E-4; + const double cephes_exp_p1 = 1.3981999507E-3; + const double cephes_exp_p2 = 8.3334519073E-3; + const double cephes_exp_p3 = 4.1665795894E-2; + const double cephes_exp_p4 = 1.6666665459E-1; + const double cephes_exp_p5 = 5.0000001201E-1; + + llvm::Value* input = &*vector_exp_function->arg_begin(); + llvm::Value* input_clamped = + vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi); + llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, 0.5)); + llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx); + llvm::Value* z = vsl.Mul(cephes_exp_C2, fx); + llvm::Value* x = vsl.Sub(input_clamped, tmp); + x = vsl.Sub(x, z); + z = vsl.Mul(x, x); + + llvm::Value* y = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1); + y = vsl.MulAdd(y, x, cephes_exp_p2); + y = vsl.MulAdd(y, x, cephes_exp_p3); + y = vsl.MulAdd(y, x, 
cephes_exp_p4); + y = vsl.MulAdd(y, x, cephes_exp_p5); + y = vsl.MulAdd(y, z, x); + y = vsl.Add(1.0, y); + + // VectorSupportLibrary (intentionally) can't juggle more than one type at a + // time so drop down to IRBuilder for this bit. + llvm::Value* vector_constant_0x7f = + ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f)); + llvm::Value* vector_constant_23 = + ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23)); + llvm::Type* i32_vector_type = + llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width); + // fx is clamped so we don't have to worry about it being out of range for + // i32. + llvm::Value* emm0 = ir_builder.CreateFPToSI(fx, i32_vector_type); + emm0 = ir_builder.CreateAdd(emm0, vector_constant_0x7f); + emm0 = ir_builder.CreateShl(emm0, vector_constant_23); + llvm::Value* emm0_f32 = ir_builder.CreateBitCast(emm0, vsl.vector_type()); + + llvm::Value* result = vsl.Max(vsl.Mul(y, emm0_f32), input); + + ir_builder.CreateRet(result); + + CHECK(!llvm::verifyFunction(*vector_exp_function)); + return vector_exp_function; +} } // namespace void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { @@ -108,11 +180,18 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName, /*vector_width=*/8, enable_fast_math); + auto* exp_v4f32 = + EmitVectorF32ExpIfNeeded(module, kExpV4F32SymbolName, + /*vector_width=*/4, enable_fast_math); + auto* exp_v8f32 = + EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName, + /*vector_width=*/8, enable_fast_math); + // Gather all the call sites, force inline them and then delete the vector // function bodies. 
std::vector<llvm::CallInst*> calls_to_inline; - for (auto* function : {tanh_v4f32, tanh_v8f32}) { + for (auto* function : {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32}) { if (function != nullptr) { for (auto* user : function->users()) { calls_to_inline.push_back(llvm::cast<llvm::CallInst>(user)); @@ -125,7 +204,7 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { CHECK(llvm::InlineFunction(call_to_inline, inline_function_info)); } - for (auto* function : {tanh_v4f32, tanh_v8f32}) { + for (auto* function : {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32}) { if (function != nullptr) { function->eraseFromParent(); } diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h index 7f31fb98b0..90050c4459 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTINE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTINE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTIME_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTIME_H_ #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -25,6 +25,8 @@ namespace runtime { extern const char* const kTanhV4F32SymbolName; extern const char* const kTanhV8F32SymbolName; +extern const char* const kExpV4F32SymbolName; +extern const char* const kExpV8F32SymbolName; // The following CPU runtime functions have LLVM-IR only implementations: // @@ -40,4 +42,4 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math); } // namespace cpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTINE_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTIME_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index de5e9b4119..34c8e31060 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -216,15 +216,12 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); #ifdef TF_XLA_HAS_NEON - REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32NEON); REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON); #endif #ifdef TF_XLA_HAS_SSE4_1 - REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE); REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE); #endif #ifdef TF_XLA_HAS_AVX - REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX); REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX); #endif REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); @@ -273,15 +270,15 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(ilogb, int (*)(double)); REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, 
int)); REGISTER_LIBM_SYMBOL(lgamma, double (*)(double)); - REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); - REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); + REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); // NOLINT(runtime/int) + REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); // NOLINT(runtime/int) REGISTER_LIBM_SYMBOL(log, double (*)(double)); REGISTER_LIBM_SYMBOL(log10, double (*)(double)); REGISTER_LIBM_SYMBOL(log1p, double (*)(double)); REGISTER_LIBM_SYMBOL(log2, double (*)(double)); REGISTER_LIBM_SYMBOL(logb, double (*)(double)); - REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); - REGISTER_LIBM_SYMBOL(lround, long (*)(double)); + REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); // NOLINT(runtime/int) + REGISTER_LIBM_SYMBOL(lround, long (*)(double)); // NOLINT(runtime/int) REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*)); REGISTER_LIBM_SYMBOL(nan, double (*)(const char*)); REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double)); @@ -292,7 +289,8 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*)); REGISTER_LIBM_SYMBOL(rint, double (*)(double)); REGISTER_LIBM_SYMBOL(round, double (*)(double)); - REGISTER_LIBM_SYMBOL(scalbln, double (*)(double, long)); + REGISTER_LIBM_SYMBOL(scalbln, + double (*)(double, long)); // NOLINT(runtime/int) REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int)); REGISTER_LIBM_SYMBOL(sin, double (*)(double)); #ifdef __APPLE__ diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 128b465be2..ec4215b468 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -35,8 +36,27 @@ VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); } +static string TypeToString(llvm::Type* type) { + std::string o; + llvm::raw_string_ostream ostream(o); + type->print(ostream); + return ostream.str(); +} + +void VectorSupportLibrary::AssertCorrectTypes( + std::initializer_list<llvm::Value*> values) { + for (llvm::Value* v : values) { + llvm::Type* type = v->getType(); + if (type != scalar_type() && type != vector_type()) { + LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or " + << TypeToString(vector_type()) << " but got " + << TypeToString(type); + } + } +} + llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { - CHECK(lhs->getType() == scalar_type() || lhs->getType() == vector_type()); + AssertCorrectTypes({lhs, rhs}); return MulInternal(lhs, rhs); } @@ -50,10 +70,50 @@ llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs, } llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) { - CHECK(lhs->getType() == scalar_type() || lhs->getType() == vector_type()); + AssertCorrectTypes({lhs, rhs}); return AddInternal(lhs, rhs); } +llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return ir_builder()->CreateFSub(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + if (scalar_type_->isFloatingPointTy()) { + return llvm_ir::EmitFloatMax(lhs, rhs, ir_builder_); + } else { + LOG(FATAL) << "Max for integers is unimplemented"; + } +} + +llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) { + 
AssertCorrectTypes({a}); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a}, + {a->getType()}, ir_builder()); +} + +llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + if (scalar_type_->isFloatingPointTy()) { + return ir_builder()->CreateFDiv(lhs, rhs, name()); + } else { + LOG(FATAL) << "Division for integers is unimplemented"; + } +} + +llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, double low, + double high) { + AssertCorrectTypes({a}); + llvm::Type* type = a->getType(); + CHECK_LT(low, high); + CHECK(scalar_type_->isFloatingPointTy()); + return llvm_ir::EmitFloatMin( + llvm_ir::EmitFloatMax(a, llvm::ConstantFP::get(type, low), ir_builder_), + llvm::ConstantFP::get(type, high), ir_builder_); +} + llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs, llvm::Value* rhs) { if (scalar_type_->isFloatingPointTy()) { @@ -93,6 +153,7 @@ llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) { void VectorSupportLibrary::StoreVector(llvm::Value* value, llvm::Value* pointer) { + AssertCorrectTypes({value}); if (pointer->getType() != vector_pointer_type()) { pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type()); } @@ -102,6 +163,7 @@ void VectorSupportLibrary::StoreVector(llvm::Value* value, void VectorSupportLibrary::StoreScalar(llvm::Value* value, llvm::Value* pointer) { + AssertCorrectTypes({value}); if (pointer->getType() != scalar_pointer_type()) { pointer = ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 8fbac2a667..5c5d703db5 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -41,16 +41,42 @@ class VectorSupportLibrary { llvm::Value* Mul(int64 lhs, llvm::Value* rhs) { return 
Mul(ir_builder()->getInt64(lhs), rhs); } + llvm::Value* Mul(double lhs, llvm::Value* rhs) { + return Mul(llvm::ConstantFP::get(rhs->getType(), lhs), rhs); + } llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* Add(int64 lhs, llvm::Value* rhs) { return Add(ir_builder()->getInt64(lhs), rhs); } + llvm::Value* Add(double lhs, llvm::Value* rhs) { + return Add(llvm::ConstantFP::get(vector_type(), lhs), rhs); + } + + llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) { return Add(c, Mul(a, b)); } + llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, double c) { + return Add(llvm::ConstantFP::get(vector_type(), c), Mul(a, b)); + } + + llvm::Value* MulAdd(llvm::Value* a, double b, double c) { + return Add(llvm::ConstantFP::get(a->getType(), c), + Mul(a, llvm::ConstantFP::get(a->getType(), b))); + } + + llvm::Value* Floor(llvm::Value* a); + + llvm::Value* Clamp(llvm::Value* a, double low, double high); + llvm::Value* SplatFloat(double d) { + return llvm::ConstantFP::get(vector_type(), d); + } + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, llvm::Value* offset_elements); llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, @@ -144,6 +170,11 @@ class VectorSupportLibrary { llvm::Value* AddReduce(llvm::Value* vector); + // Checks that each value in `values` is either of type scalar_type() or + // vector_type(). This LOG(FATAL)'s so it should only be called in cases + // where a mismatching type is a programmer bug. + void AssertCorrectTypes(std::initializer_list<llvm::Value*> values); + // Perform an X86 AVX style horizontal add between `lhs` and `rhs`. 
The // resulting IR for an 8-float wide vector is expected to lower to a single // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index c83880e030..9171e859c6 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -44,28 +44,12 @@ namespace interpreter { namespace se = ::perftools::gputools; namespace sep = ::perftools::gputools::interpreter; -/* - * Run optimization passes on the module. The graph is transformed by - * each pass in the optimization pipeline. The service subdirectory - * contains useful optimization passes. - */ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); - pipeline.AddPass<Inliner>(); - pipeline.AddPass<HloSubcomputationUnification>(); - pipeline.AddPass<HloCSE>(false); - - pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>( - false, [](const Shape&, const Shape&) { return false; }); - pipeline.AddPass<WhileLoopSimplifier>(); - pipeline.AddPass<ReshapeMover>(); - pipeline.AddPass<HloConstantFolding>(); - pipeline.AddPass<HloCSE>(true); + pipeline.AddPass<LayoutAssignment>( hlo_module->mutable_entry_computation_layout()); - pipeline.AddPass<HloDCE>(); - pipeline.AddPass<FlattenCallGraph>(); return pipeline.Run(hlo_module).status(); } diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index b788631fa3..87ac7731ba 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -2048,47 +2048,79 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that // the input tensor is 
large enough to exercise the vectorized tanh - // implementation. - ComputationBuilder builder(client_, TestName()); - auto input_literal = Literal::CreateR2<float>( - {{1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80}, - {-0.67, 0.16, -0.07, 0.39, -0.41, 0.04, 1.36, 1.25}, - {0.41, 0.65, -1.08, 0.32, -1.45, -0.77, -1.09, 0.91}, - {-1.03, -0.30, -1.11, -1.17, 1.50, -0.85, 0.04, 1.02}, - {0.34, -0.61, 0.41, 0.07, -0.02, 1.42, -0.62, 0.81}, - {0.08, 0.81, -0.30, 1.17, -0.65, -0.44, 0.92, 1.26}, - {-1.29, 1.35, 0.08, -1.24, -0.92, 0.49, 1.17, -0.45}, - {-1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05}}); - auto input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + // implementation on XLA CPU. + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateR1<float>( + {1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16, + -0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32, + -1.45, -0.77, -1.09, 0.91, -1.03, -0.30, -1.11, -1.17, 1.50, -0.85, + 0.04, 1.02, 0.34, -0.61, 0.41, 0.07, -0.02, 1.42, -0.62, 0.81, + 0.08, 0.81, -0.30, 1.17, -0.65, -0.44, 0.92, 1.26, -1.29, 1.35, + 0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31, + -0.79, 1.41, 1.21, 1.05}); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + client_->TransferToServer(*input_literal)); auto input = builder.Parameter(0, input_literal->shape(), "input"); builder.Tanh(input); - ComputeAndCompareR2<float>( + ComputeAndCompareR1<float>( &builder, - {{0.77009583, -0.30665702, 0.69070244, 0.71401149, 0.84400684, - -0.71985596, -0.45764771, 0.66664988}, - {-0.58278900, 0.16050975, -0.06770509, 0.36843640, -0.38476998, - 0.04018109, 0.87562293, 0.84788644}, - {0.38603750, 0.57294142, -0.79140943, 0.31032649, -0.89590985, - -0.64770776, -0.79625875, 0.72234446}, - {-0.77389336, -0.28871772, -0.80428445, -0.82541436, 0.90456349, - -0.68856895, 0.03877772, 0.76877952}, - {0.32561871, -0.54546672, 0.39072621, 0.07273290, -0.01924866, 
- 0.88924897, -0.55283129, 0.67183107}, - {0.08006320, 0.66944766, -0.29068485, 0.82573754, -0.57170743, - -0.41581789, 0.72739530, 0.85025692}, - {-0.85931867, 0.87357593, 0.07782833, -0.84597743, -0.72748238, - 0.45396307, 0.82449573, -0.42462519}, - {-0.86363792, -0.89368379, -0.12621804, -0.86445558, -0.65565848, - 0.88789743, 0.83566397, 0.78287679}}, + {0.77009583, -0.30665702, 0.69070244, 0.71401149, 0.84400684, + -0.71985596, -0.45764771, 0.66664988, -0.58278900, 0.16050975, + -0.06770509, 0.36843640, -0.38476998, 0.04018109, 0.87562293, + 0.84788644, 0.38603750, 0.57294142, -0.79140943, 0.31032649, + -0.89590985, -0.64770776, -0.79625875, 0.72234446, -0.77389336, + -0.28871772, -0.80428445, -0.82541436, 0.90456349, -0.68856895, + 0.03877772, 0.76877952, 0.32561871, -0.54546672, 0.39072621, + 0.07273290, -0.01924866, 0.88924897, -0.55283129, 0.67183107, + 0.08006320, 0.66944766, -0.29068485, 0.82573754, -0.57170743, + -0.41581789, 0.72739530, 0.85025692, -0.85931867, 0.87357593, + 0.07782833, -0.84597743, -0.72748238, 0.45396307, 0.82449573, + -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558, + -0.65565848, 0.88789743, 0.83566397, 0.78287679}, {input_data.get()}, // The error spec is unusually high here to account for the fact that we // use a rational interpolant to approximate tanh. ErrorSpec(0.004, 0.004)); } +XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { + // The input tensor is large enough to exercise the vectorized exp + // implementation on XLA CPU. + ComputationBuilder builder(client_, TestName()); + + // Just to help make sense of the scales here -- exp(89) saturates float32 and + // exp(-10) is smaller than our error spec. 
+ std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>( + {1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31, + -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5, + -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4, + -16.3, -15.2, -14.1, -13.0, -11.9, -10.8, -9.7, -8.6, -7.5, + -6.4, -5.3, -4.2, -3.1, -2.0, -0.9, 0.2, 1.3, 2.4, + 3.5, 4.6, 5.7, 6.8, 7.9, 9.0, 10.1, 11.2, 12.3, + 13.4, 14.5, 15.6, 16.7, 17.8, 18.9, 20.0, 21.1, 22.2, + 23.3, 24.4, 25.5, 26.6, 27.7, 28.8, 29.9, 31.0, 32.1, + 68.4, 69.5, 70.6, 71.7, 72.8, 73.9, 75.0, 76.1, 77.2, + 78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3, + 86.4, 86.5, 87.6, 87.7, 87.8, 87.9}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data, + client_->TransferToServer(*input_literal)); + + auto input = builder.Parameter(0, input_literal->shape(), "input"); + builder.Exp(input); + + std::vector<float> expected_result; + int64 input_size = input_literal->shape().dimensions(0); + expected_result.reserve(input_size); + for (int64 i = 0; i < input_size; i++) { + expected_result.push_back(std::exp(input_literal->Get<float>({i}))); + } + + ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()}, + error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // a ------ (add) --------- (add) // / / diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 55f42ed3a4..93284b80f9 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -32,6 +32,8 @@ Window MakeWindow(tensorflow::gtl::ArraySlice<int64> sizes) { auto* dimension = window.add_dimensions(); dimension->set_size(size); dimension->set_stride(1); + dimension->set_base_dilation(1); + dimension->set_window_dilation(1); } return window; } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 3ed8cef56c..f48c2fe92d 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD 
@@ -71,7 +71,6 @@ py_library( "//tensorflow/contrib/metrics:metrics_py", "//tensorflow/contrib/model_pruning", "//tensorflow/contrib/nccl:nccl_py", - "//tensorflow/contrib/ndlstm", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 46b579b889..4f6f539027 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -84,7 +84,6 @@ from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager from tensorflow.contrib.lite.python import lite -from tensorflow.contrib.ndlstm import python as ndlstm from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index e51e3f747b..abddadac5b 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -197,9 +197,7 @@ public class TensorFlowInferenceInterface { run(outputNames, enableStats, new String[] {}); } - /** - * An overloaded version of runInference that allows supplying targetNodeNames as well - */ + /** An overloaded version of runInference that allows supplying targetNodeNames as well */ public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) { // Release any Tensors from the previous run calls. 
closeFetches(); @@ -211,7 +209,7 @@ public class TensorFlowInferenceInterface { runner.fetch(tid.name, tid.outputIndex); } - // Add targets. + // Add targets. for (String t : targetNodeNames) { runner.addTarget(t); } diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 57a52bf4ca..2720c43b78 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -329,8 +329,6 @@ tensorflow/contrib/nccl/kernels tensorflow/contrib/nccl/ops tensorflow/contrib/nccl/python tensorflow/contrib/nccl/python/ops -tensorflow/contrib/ndlstm -tensorflow/contrib/ndlstm/python tensorflow/contrib/nearest_neighbor/kernels tensorflow/contrib/nearest_neighbor/ops tensorflow/contrib/nearest_neighbor/python diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index e4213ea2a4..96ac60d095 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -50,6 +50,12 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" "${tensorflow_source_dir}/tensorflow/core/graph/graph.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.h" + "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.h" + "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.h" + "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.cc" "${tensorflow_source_dir}/tensorflow/core/graph/while_context.h" "${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h" diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 
129c208ecd..a1c320347f 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -292,6 +292,12 @@ file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" "${tensorflow_source_dir}/tensorflow/core/graph/graph.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.h" + "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.h" + "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.h" + "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.cc" "${tensorflow_source_dir}/tensorflow/core/graph/while_context.h" "${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc" "${tensorflow_source_dir}/tensorflow/core/util/*.h" diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py index 77ea914380..53c2285699 100644 --- a/tensorflow/contrib/cmake/tools/create_def_file.py +++ b/tensorflow/contrib/cmake/tools/create_def_file.py @@ -32,7 +32,6 @@ from __future__ import print_function import argparse import codecs -import io import os import re import subprocess diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py index 2798d31229..05744bec4e 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_test.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py @@ -17,9 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np from tensorflow.contrib.copy_graph.python.util import copy_elements -from tensorflow.contrib.framework.python.framework import tensor_util from tensorflow.python.client import session as session_lib from 
tensorflow.python.framework import constant_op from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py index 56c562a3ba..933df6d71d 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py @@ -20,7 +20,7 @@ from __future__ import print_function import time -from six.moves import xrange +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import rnn as contrib_rnn from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.contrib.rnn.python.ops import lstm_ops diff --git a/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto index c962638aa1..b4a39e6c68 100644 --- a/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto +++ b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto @@ -14,7 +14,8 @@ message CheckpointableObjectGraph { // An index into `CheckpointableObjectGraph.nodes`, indicating the object // being referenced. int32 node_id = 1; - // A numeric identifier for this object within its parent. + // A numeric identifier for this object within its parent. Zero means + // unset, in which case there should be a local_name. int32 local_uid = 2; // A user-provided name for the edge. May be blank/omitted, in which case // there is no explicitly provided local name; fall back on local_uid. @@ -28,6 +29,8 @@ message CheckpointableObjectGraph { // The full name of the variable. Used to allow name-based loading of // checkpoints which were saved using an object-based API. string full_name = 2; + // The generated name of the variable in the checkpoint. 
+ string checkpoint_key = 3; } message SlotVariableReference { @@ -42,6 +45,8 @@ message CheckpointableObjectGraph { // The full name of the slot variable. Used to allow name-based loading of // checkpoints which were saved using an object-based API. string full_name = 4; + // The generated name of the variable in the checkpoint. + string checkpoint_key = 5; } // Objects which this object depends on. diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index e984c63af7..3df056b070 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -226,8 +226,12 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/contrib/eager/proto:checkpointable_object_graph_proto_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", ], ) @@ -243,6 +247,8 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:layers", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/eager/python/checkpointable.py b/tensorflow/contrib/eager/python/checkpointable.py index b141ffb2bc..47ce5897c0 100644 --- a/tensorflow/contrib/eager/python/checkpointable.py +++ b/tensorflow/contrib/eager/python/checkpointable.py @@ -19,28 +19,30 @@ from __future__ import print_function import collections import re +import weakref from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from 
tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import slot_creator +from tensorflow.python.training import training _CheckpointableReference = collections.namedtuple( "_CheckpointableReference", [ - "name", # The local name if explicitly specified, else None. - "local_uid", # 0 for the first dependency, 1 for the next, ... Used for - # routing checkpointed variables to their correct - # Checkpointables when "name" is not set (see docstring of - # `track_checkpointable`). - "ref" # The Checkpointable object being referenced. - ]) - -_OwnedVariable = collections.namedtuple( - "_OwnedVariable", - [ - "name", # The variable's (local) name. - "variable" # The owned variable object. + # The local name if explicitly specified, else None. + "name", + # 1 for the first dependency, 2 for the next, ... Used for routing + # checkpointed variables to their correct Checkpointables when "name" is + # not set (see docstring of `track_checkpointable`). + "local_uid", + # The Checkpointable object being referenced. 
+ "ref" ]) # Validation regular expression for the local names of Checkpointable @@ -59,6 +61,23 @@ _VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.-]*$") _OPTIMIZER_SLOTS_NAME = "_OPTIMIZER_SLOT" +def _assign_existing_variable(variable_to_restore, value_pointer): + """Set a variable from a _ValuePointer object.""" + base_type = variable_to_restore.dtype.base_dtype + with ops.colocate_with(variable_to_restore): + # TODO(allenl): Handle partitioned variables + value_to_restore, = io_ops.restore_v2( + prefix=value_pointer.save_path, + tensor_names=[value_pointer.checkpoint_key], + shape_and_slices=[""], + dtypes=[base_type], + name="checkpoint_initializer") + initializer_op = state_ops.assign(variable_to_restore, value_to_restore) + variable_to_restore._initializer_op = initializer_op # pylint:disable=protected-access + if value_pointer.session is not None: + value_pointer.session.run(initializer_op) + + class Checkpointable(object): """Manages variables and dependencies on other objects. @@ -70,14 +89,18 @@ class Checkpointable(object): """ def __init__(self): - # Basically less useful OrderedDicts but without the reference cycles. - # TODO(allenl): Switch these to OrderedDict once TensorFlow supports only + # Basically a less useful OrderedDict but without the reference cycles. + # TODO(allenl): Switch this to OrderedDict once TensorFlow supports only # Python 3.6+. - self._checkpoint_dependencies = [] # A list of _CheckpointableReference - # objects. + # A list of _CheckpointableReference objects. + self._checkpoint_dependencies = [] self._dependency_names = set() - self._owned_variables = [] # A list of _OwnedVariable objects. - self._owned_variable_names = set() + # Start numbering at 1, since an un-set protocol buffer integer is + # indistinguishable from 0. 
+ self._next_unnamed_checkpoint_dependency_uid = 1 + self._owned_variables = {} # local name -> variable object + self._deferred_restorations = {} # local name -> _VariableRestoration + # object def add_variable(self, name, shape, dtype=None, initializer=None, **kwargs): """Create a new variable object to be saved with this `Checkpointable`. @@ -101,7 +124,7 @@ class Checkpointable(object): Raises: ValueError: If the variable name is not unique. """ - if name in self._owned_variable_names: + if name in self._owned_variables: raise ValueError( ("A variable named '%s' already exists in this Checkpointable, but " "Checkpointable.add_variable called to create another with " @@ -114,12 +137,38 @@ class Checkpointable(object): getter = kwargs.pop("getter") else: getter = variable_scope.get_variable - # TODO(allenl): handle deferred loading + deferred_restoration = self._deferred_restorations.pop(name, None) + if deferred_restoration is not None: + dtype = deferred_restoration.value_pointer.dtype + base_type = dtype.base_dtype + # TODO(allenl): Handle partitioned variables here too + initializer, = io_ops.restore_v2( + prefix=deferred_restoration.value_pointer.save_path, + tensor_names=[deferred_restoration.value_pointer.checkpoint_key], + shape_and_slices=[""], + dtypes=[base_type], + name="checkpoint_initializer") + # We need to un-set the shape so get_variable doesn't complain, but we + # also need to set the static shape information on the initializer if + # possible so we don't get a variable with an unknown shape. 
+ initializer.set_shape(shape) + # Un-set shape since we're using a constant initializer + shape = None + new_variable = getter( name=name, shape=shape, dtype=dtype, initializer=initializer, **kwargs) - self._owned_variables.append( - _OwnedVariable(name=name, variable=new_variable)) - self._owned_variable_names.add(name) + if deferred_restoration is not None: + if deferred_restoration.value_pointer.session is not None: + deferred_restoration.value_pointer.session.run(new_variable.initializer) + for slot_restoration in deferred_restoration.slot_restorations: + strong_ref = slot_restoration.optimizer_ref() + if strong_ref is None: + # If the optimizer object has been garbage collected, there's no need + # to create the slot variable. + continue + strong_ref._process_slot_restoration( # pylint: disable=protected-access + slot_restoration, new_variable) + self._owned_variables[name] = new_variable return new_variable def track_checkpointable(self, checkpointable, name=None): @@ -130,13 +179,15 @@ class Checkpointable(object): Variables in a checkpoint are mapped to `Checkpointable`s based on names if provided when the checkpoint was written, but otherwise use the order those - `Checkpointable`s were declared as dependencies. Both `name` arguments and - the dependency declaration order should be deterministic. + `Checkpointable`s were declared as dependencies. - There are two sufficient conditions to avoid breaking existing checkpoints - when modifying a class: (1) New dependencies must be declared after existing - dependencies, and (2) dependencies which were previously declared may never - be removed (a trivial placeholder with the same name may be used instead). 
+ There are three sufficient conditions to avoid breaking existing checkpoints + when modifying a class: (1) New un-named dependencies must be declared after + existing un-named dependencies, (2) un-named dependencies which were + previously declared may never be removed (a trivial placeholder may be used + instead if the dependency is no longer needed), and (3) names may not change + (un-named dependencies may not later be named, named dependencies must keep + the same name). Args: checkpointable: A `Checkpointable` which this object depends on. @@ -172,16 +223,62 @@ class Checkpointable(object): "a Checkpointable with this name is already declared as a " "dependency. If provided, names must be unique.") % (name,)) self._dependency_names.add(name) + local_uid = None + else: + # TODO(allenl): Should this be exposed to allow users to stop depending on + # things and still load checkpoints when not using names? + local_uid = self._next_unnamed_checkpoint_dependency_uid + self._next_unnamed_checkpoint_dependency_uid += 1 self._checkpoint_dependencies.append( _CheckpointableReference( - name=name, - ref=checkpointable, - # TODO(allenl): Should this be exposed to allow users to stop - # depending on things and still load checkpoints when not using - # names? - local_uid=len(self._checkpoint_dependencies))) + name=name, ref=checkpointable, local_uid=local_uid)) return checkpointable + def _process_restoration(self, restoration): + """Restore a variable and its slot variables (may be deferred).""" + variable_to_restore = self._owned_variables.get(restoration.name, None) + if variable_to_restore is not None: + # This variable already exists, so just do an assignment for this and any + # slot variables which depend on it. 
+ _assign_existing_variable( + variable_to_restore, value_pointer=restoration.value_pointer) + for slot_restoration in restoration.slot_restorations: + strong_ref = slot_restoration.optimizer_ref() + if strong_ref is None: + continue + strong_ref._process_slot_restoration( # pylint: disable=protected-access + slot_restoration, variable_to_restore) + else: + # Save this restoration for later. This intentionally overwrites any + # previous deferred restorations, since that gives the same semantics as + # direct assignment. + self._deferred_restorations[restoration.name] = restoration + + def _process_slot_restoration(self, slot_restoration, variable): + """Restore a slot variable's value (creating it if necessary).""" + # TODO(allenl): Move this to Optimizer + assert isinstance(self, optimizer_lib.Optimizer) + named_slots = self._slot_dict(slot_restoration.slot_name) + variable_key = optimizer_lib._var_key(variable) # pylint: disable=protected-access + existing_slot_variable = named_slots.get(variable_key, None) + if existing_slot_variable is None: + base_dtype = slot_restoration.value_pointer.dtype.base_dtype + initializer, = io_ops.restore_v2( + prefix=slot_restoration.value_pointer.save_path, + tensor_names=[slot_restoration.value_pointer.checkpoint_key], + shape_and_slices=[""], + dtypes=[base_dtype], + name="checkpoint_initializer") + new_slot_variable = slot_creator.create_slot(variable, initializer, + slot_restoration.slot_name) + if slot_restoration.value_pointer.session is not None: + slot_restoration.value_pointer.session.run( + new_slot_variable.initializer) + named_slots[variable_key] = new_slot_variable + else: + _assign_existing_variable( + existing_slot_variable, value_pointer=slot_restoration.value_pointer) + @property def checkpoint_dependencies(self): """Other `Checkpointable` objects on which this object depends.""" @@ -237,9 +334,9 @@ def _variable_naming_for_object(path_to_root): if object_prefix: object_prefix += "/" - def 
_name_single_variable(owned_variable): + def _name_single_variable(local_name): """Names a variable within an object.""" - return object_prefix + _escape_variable_name(owned_variable.name) + return object_prefix + _escape_variable_name(local_name) return _name_single_variable @@ -289,26 +386,31 @@ def _serialize_non_slot_variables(checkpointable_objects, path_to_root, for checkpoint_id, checkpointable in enumerate(checkpointable_objects): naming_scheme = _variable_naming_for_object(path_to_root[checkpointable]) object_proto = object_graph_proto.nodes.add() - for owned_variable in checkpointable.ref._owned_variables: # pylint: disable=protected-access - variable_name = naming_scheme(owned_variable) - named_variables[variable_name] = owned_variable.variable + for (local_name, owned_variable) in sorted( + checkpointable.ref._owned_variables.items(), # pylint: disable=protected-access + key=lambda x: x[0]): + variable_name = naming_scheme(local_name) + named_variables[variable_name] = owned_variable non_slot_variables.append(( variable_name, # The variable's full checkpoint name - owned_variable, # The variable's _OwnedVariable object + owned_variable, # The variable object + local_name, # The variable's local name checkpoint_id)) # The checkpoint ID of the node which owns this # variable. variable_proto = object_proto.variables.add() - variable_proto.local_name = owned_variable.name + variable_proto.local_name = local_name + variable_proto.checkpoint_key = variable_name # Figure out the name-based Saver's name for this variable. 
saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( - [owned_variable.variable], convert_variable_to_tensor=False) + [owned_variable], convert_variable_to_tensor=False) variable_full_name, = saver_dict.keys() variable_proto.full_name = variable_full_name for child in checkpointable.ref.checkpoint_dependencies: child_proto = object_proto.children.add() child_proto.node_id = checkpoint_node_ids[child] - child_proto.local_uid = child.local_uid + if child.local_uid is not None: + child_proto.local_uid = child.local_uid if child.name is not None: child_proto.local_name = child.name return named_variables, non_slot_variables @@ -326,24 +428,25 @@ def _serialize_slot_variables(checkpointable_objects, path_to_root, optimizer=checkpointable_ref.ref, path_to_root=path_to_root[checkpointable_ref]) slot_names = checkpointable_ref.ref.get_slot_names() - for (variable_path, owned_variable, + for (variable_path, original_variable, original_variable_local_name, original_node_checkpoint_id) in non_slot_variables: for slot_name in slot_names: slot_variable = checkpointable_ref.ref.get_slot( - owned_variable.variable, slot_name) + original_variable, slot_name) if slot_variable is not None: checkpoint_name = naming_scheme( variable_path=variable_path, slot_name=slot_name) named_slot_variables[checkpoint_name] = slot_variable slot_variable_proto = optimizer_object_proto.slot_variables.add() slot_variable_proto.slot_name = slot_name + slot_variable_proto.checkpoint_key = checkpoint_name # Figure out the name-based Saver's name for this variable. 
saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( [slot_variable], convert_variable_to_tensor=False) slot_variable_full_name, = saver_dict.keys() slot_variable_proto.full_name = slot_variable_full_name slot_variable_proto.original_variable_local_name = ( - owned_variable.name) + original_variable_local_name) slot_variable_proto.original_variable_node_id = ( original_node_checkpoint_id) return named_slot_variables @@ -390,3 +493,250 @@ def _serialize_object_graph(root_checkpointable): named_variables.update(named_slot_variables) return named_variables, object_graph_proto + + +def _set_reference(reference_proto_table, key, checkpointable, parent, + object_id_map): + """Record a checkpoint<->object correspondence, with error checking. + + Args: + reference_proto_table: Map from names or numbers to `ObjectReference` protos + within the parent object. + key: Either a numeric or string identifier for the reference. + checkpointable: The object to record a correspondence for. + parent: The parent Python object, for creating a useful error message. + object_id_map: The map from `node_id` to Python object in which to record + the reference. + Returns: + The `node_id` of the Object proto corresponding to the specified Python + object. + Raises: + AssertionError: If another object is already bound to the `Object` proto. + """ + reference_proto = reference_proto_table[key] + set_reference = object_id_map.setdefault(reference_proto.node_id, + checkpointable) + if set_reference is not checkpointable: + raise AssertionError( + ("Unable to load the checkpoint into this object graph. 
Either " + "the Checkpointable object references in the Python program " + "have changed in an incompatible way, or the checkpoint was " + "generated in an incompatible program.\n\nTwo checkpoint " + "references (one being '%s' in %s) resolved to different " + "objects (%s and %s).") % (key, parent, set_reference, + checkpointable)) + return reference_proto.node_id + + +def _checkpoint_object_id_map(root_checkpointable, object_graph_proto): + """Match a checkpointed object graph to a Python object graph. + + Args: + root_checkpointable: A Checkpointable object. + object_graph_proto: A CheckpointableObjectGraph protocol buffer representing + a serialized object graph. + Returns: + A dictionary mapping from checkpoint node ids (indices into + `object_graph_proto.nodes`) to `Checkpointable` objects which are + dependencies of `root_checkpointable`. + """ + node_list = object_graph_proto.nodes + # Queue of (checkpointable object, node id) + to_visit = collections.deque([(root_checkpointable, 0)]) + object_id_map = {0: root_checkpointable} + seen = set() + while to_visit: + checkpointable, node_id = to_visit.popleft() + object_proto = node_list[node_id] + named_children = {} + numbered_children = {} + for child_reference in object_proto.children: + if child_reference.local_name: + named_children[child_reference.local_name] = child_reference + else: + if not child_reference.local_uid: + raise AssertionError( + ("The checkpointed object graph contains a reference with " + "neither a name nor a number (corrupted?). 
The reference was " + "from the node %s.") % (object_proto,)) + numbered_children[child_reference.local_uid] = child_reference + + for checkpointable_reference in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access + if checkpointable_reference.name is not None: + child_node_id = _set_reference( + reference_proto_table=named_children, + key=checkpointable_reference.name, + checkpointable=checkpointable_reference.ref, + parent=checkpointable, + object_id_map=object_id_map) + else: + if checkpointable_reference.local_uid is None: + raise AssertionError( + ("A Checkpointable reference was created with no name and no " + "number in %s.") % (checkpointable,)) + child_node_id = _set_reference( + reference_proto_table=numbered_children, + key=checkpointable_reference.local_uid, + checkpointable=checkpointable_reference.ref, + parent=checkpointable, + object_id_map=object_id_map) + if child_node_id not in seen: + seen.add(child_node_id) + to_visit.append((checkpointable_reference.ref, child_node_id)) + + return object_id_map + + +_ValuePointer = collections.namedtuple( + "_ValuePointer", + [ + # Information needed to look up the value to restore. + "save_path", + "checkpoint_key", + "dtype", + # The session to use when restoring (None when executing eagerly) + "session", + ]) + +_SlotVariableRestoration = collections.namedtuple( + "_SlotVariableRestoration", + [ + # A weak reference to the Optimizer object + "optimizer_ref", + # The slot name + "slot_name", + # The _ValuePointer to use when restoring + "value_pointer", + ]) + +_VariableRestoration = collections.namedtuple( + "_VariableRestoration", + [ + # The variable's (local) name. + "name", + # _SlotVariableRestoration objects indicating slot variables which + # should be created once this variable has been restored. 
+ "slot_restorations", + # The _ValuePointer to use when restoring + "value_pointer", + ]) + + +def _gather_restorations(object_graph_proto, save_path, object_id_map, + dtype_map, session): + """Iterate over variables to restore, matching with Checkpointable objects.""" + variable_to_slot_restorations = {} + for node_id, node in enumerate(object_graph_proto.nodes): + for slot_variable in node.slot_variables: + original_variable_key = (slot_variable.original_variable_node_id, + slot_variable.original_variable_local_name) + variable_to_slot_restorations.setdefault( + original_variable_key, []).append( + _SlotVariableRestoration( + optimizer_ref=weakref.ref(object_id_map[node_id]), + slot_name=slot_variable.slot_name, + value_pointer=_ValuePointer( + save_path=save_path, + checkpoint_key=slot_variable.checkpoint_key, + dtype=dtype_map[slot_variable.checkpoint_key], + session=session))) + + for node_id, node in enumerate(object_graph_proto.nodes): + for variable in node.variables: + slots_key = (node_id, variable.local_name) + variable_restore = _VariableRestoration( + name=variable.local_name, + slot_restorations=variable_to_slot_restorations.get(slots_key, []), + value_pointer=_ValuePointer( + save_path=save_path, + checkpoint_key=variable.checkpoint_key, + dtype=dtype_map[variable.checkpoint_key], + session=session)) + yield variable_restore, object_id_map[node_id] + + +def save(file_prefix, root_checkpointable, global_step=None, session=None): + """Save a training checkpoint. + + Args: + file_prefix: A prefix to use for the checkpoint filenames + (/path/to/directory/and_a_prefix). Names are generated based on this + prefix and the global step, if provided. + root_checkpointable: A Checkpointable object to save. The checkpoint + includes variables created by this object and any Checkpointable objects + it depends on. + global_step: An integer variable or Tensor, used to number + checkpoints. 
Typically this value is saved along with other variables in + training checkpoints, which will happen automatically if it was created by + `root_checkpointable` or one of its dependencies (via + `Checkpointable.add_variable`). + session: The session to evaluate variables in. Ignored when executing + eagerly. If not provided when graph building, the default session is used. + + Returns: + The full path to the checkpoint. + + Currently also returns the serialized object graph proto, but that will go + away once it's saved with the checkpoint. + """ + named_variables, serialized_graph = _serialize_object_graph( + root_checkpointable) + if context.in_graph_mode(): + if session is None: + session = ops.get_default_session() + else: + session = None + with ops.device("/device:CPU:0"): + save_path = saver_lib.Saver(var_list=named_variables).save( + sess=session, + save_path=file_prefix, + write_meta_graph=False, + global_step=global_step) + # TODO(allenl): Save the graph with the checkpoint, then returning it and + # taking it as an argument to restore won't be necessary. + return serialized_graph, save_path + + +# NOTE: Will be restore(file_prefix, root_checkpointable) once the object graph +# is saved with the checkpoint. +def restore(save_path, root_checkpointable, object_graph_proto, session=None): + """Restore a training checkpoint. + + Restores the values of variables created with `Checkpointable.add_variable` in + the dependency graph of `root_checkpointable`. Either assigns values + immediately (if variables to restore have been created already), or defers + restoration until the variables are created. + + When building a graph, restorations are executed in the default session if + `session` is `None`. Variable initializers read checkpointed values. + + Args: + save_path: The path to the checkpoint, as returned by `save` or + `tf.train.latest_checkpoint`. If None (as when there is no latest + checkpoint for `tf.train.latest_checkpoint` to return), does nothing. 
+ root_checkpointable: The root of the object graph to restore. Variables to + restore need not have been created yet, but all dependencies on other + Checkpointable objects should already be declared. Objects in the + dependency graph are matched to objects in the checkpointed graph, and + matching objects have their variables restored (or the checkpointed values + saved for eventual restoration when the variable is created). + object_graph_proto: (Temporary) the checkpointed object graph. This will + eventually be saved with the checkpoint, and will not be part of the final + API. + session: The session to evaluate assignment ops in. Ignored when executing + eagerly. If not provided when graph building, the default session is used. + """ + if save_path is None: + return + object_id_map = _checkpoint_object_id_map(root_checkpointable, + object_graph_proto) + reader = training.NewCheckpointReader(save_path) + dtype_map = reader.get_variable_to_dtype_map() + if context.in_graph_mode(): + if session is None: + session = ops.get_default_session() + else: + session = None + for restoration, checkpointable in _gather_restorations( + object_graph_proto, save_path, object_id_map, dtype_map, session=session): + checkpointable._process_restoration(restoration) # pylint: disable=protected-access diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py index ff419614f5..d823053283 100644 --- a/tensorflow/contrib/eager/python/checkpointable_test.py +++ b/tensorflow/contrib/eager/python/checkpointable_test.py @@ -17,6 +17,8 @@ from __future__ import division from __future__ import print_function import functools +import os + import six from tensorflow.contrib.eager.python import checkpointable @@ -28,9 +30,12 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core +from 
tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import adam +from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util @@ -101,11 +106,6 @@ class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): return v - # TODO(allenl): Override slot variable creation (_get_or_make_slot, - # _get_or_make_slot_with_initializer, _zeros_slot) to allow deferred - # loading. Likely no need to run this through add_variable, since gathering - # slot variables is special cased anyway. - class MyNetwork(CheckpointableNetwork): """A concrete Network for testing.""" @@ -175,8 +175,8 @@ class CheckpointNamingTests(test.TestCase): expected_checkpoint_names = ( # Created in the root node, so no prefix. "global_step", - # No name provided to track_checkpointable(), so the position (1, after - # the named track_checkpointable() which is 0) is used instead. + # No name provided to track_checkpointable(), so the position is used + # instead (one-based). 
"network/_1/kernel", # track_checkpointable() with a name provided, so that's used "network/named_dense/kernel", @@ -212,20 +212,158 @@ class CheckpointNamingTests(test.TestCase): 0].node_id] self.assertEqual("beta1_power", optimizer_node.variables[0].local_name) self.assertEqual("beta1_power", optimizer_node.variables[0].full_name) + # Variable ordering is arbitrary but deterministic (alphabetized) self.assertEqual( - "kernel", optimizer_node.slot_variables[0].original_variable_local_name) + "bias", optimizer_node.slot_variables[0].original_variable_local_name) original_variable_owner = serialized_graph.nodes[ optimizer_node.slot_variables[0].original_variable_node_id] - self.assertEqual("kernel", original_variable_owner.variables[0].local_name) + self.assertEqual("network/named_dense/bias", + original_variable_owner.variables[0].checkpoint_key) + self.assertEqual("bias", original_variable_owner.variables[0].local_name) self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) + self.assertEqual("network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/m", + optimizer_node.slot_variables[0].checkpoint_key) # We strip off the :0 suffix, as variable.name-based saving does. 
- self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam", + self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam", optimizer_node.slot_variables[0].full_name) - self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam:0", + self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam:0", optimizer.get_slot( - var=named_variables["network/named_dense/kernel"], + var=named_variables["network/named_dense/bias"], name="m").name) + @test_util.run_in_graph_and_eager_modes() + def testSaveRestore(self): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root_checkpointable = Root(optimizer=optimizer, network=network) + input_value = constant_op.constant([[3.]]) + if context.in_eager_mode(): + optimizer.minimize( + lambda: network(input_value), + global_step=root_checkpointable.global_step) + else: + train_op = optimizer.minimize( + network(input_value), global_step=root_checkpointable.global_step) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(train_op) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(state_ops.assign(network._named.variables[1], [42.])) + m_bias_slot = optimizer.get_slot(network._named.variables[1], "m") + self.evaluate(state_ops.assign(m_bias_slot, [1.5])) + serialized_graph, save_path = checkpointable.save( + file_prefix=prefix, + root_checkpointable=root_checkpointable, + global_step=root_checkpointable.global_step) + self.evaluate(state_ops.assign(network._named.variables[1], [43.])) + self.evaluate(state_ops.assign(root_checkpointable.global_step, 3)) + optimizer_variables = self.evaluate(optimizer.variables()) + self.evaluate(state_ops.assign(m_bias_slot, [-2.])) + # Immediate restoration + checkpointable.restore( + save_path=save_path, + root_checkpointable=root_checkpointable, + object_graph_proto=serialized_graph) + self.assertAllEqual([42.], self.evaluate(network._named.variables[1])) + self.assertAllEqual(1, 
self.evaluate(root_checkpointable.global_step)) + self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) + with ops.Graph().as_default(): + on_create_network = MyNetwork() + on_create_optimizer = CheckpointableAdam(0.001) + on_create_root = Root( + optimizer=on_create_optimizer, network=on_create_network) + with self.test_session(graph=ops.get_default_graph()): + # Deferred restoration + checkpointable.restore( + save_path=save_path, + root_checkpointable=on_create_root, + object_graph_proto=serialized_graph) + on_create_network(constant_op.constant([[3.]])) # create variables + self.assertAllEqual(1, self.evaluate(on_create_root.global_step)) + self.assertAllEqual([42.], + self.evaluate( + on_create_network._named.variables[1])) + on_create_m_bias_slot = on_create_optimizer.get_slot( + on_create_network._named.variables[1], "m") + # Optimizer slot variables are created when the original variable is + # restored. + self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot)) + # beta1_power and beta2_power haven't been created yet, but everything + # else matches. + self.assertAllEqual(optimizer_variables[2:], + self.evaluate(on_create_optimizer.variables())) + on_create_optimizer._create_slots( + [resource_variable_ops.ResourceVariable([1.])]) + beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators() + self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power)) + self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power)) + + def testDeferredRestorationUsageEager(self): + """An idiomatic eager execution example.""" + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + latest_object_graph = None # Will be saved with the checkpoint eventually. 
+ for training_continuation in range(3): + with ops.Graph().as_default(): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root = Root(optimizer=optimizer, network=network) + checkpointable.restore( + save_path=core_saver.latest_checkpoint(checkpoint_directory), + root_checkpointable=root, + object_graph_proto=latest_object_graph) + for _ in range(num_training_steps): + # TODO(allenl): Use a Dataset and serialize/checkpoint it. + input_value = constant_op.constant([[3.]]) + optimizer.minimize( + lambda: network(input_value), # pylint: disable=cell-var-from-loop + global_step=root.global_step) + latest_object_graph, _ = checkpointable.save( + file_prefix=checkpoint_prefix, + root_checkpointable=root) + self.assertEqual((training_continuation + 1) * num_training_steps, + root.global_step.numpy()) + + def testUsageGraph(self): + """Expected usage when graph building.""" + with context.graph_mode(): + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + latest_object_graph = None + for training_continuation in range(3): + with ops.Graph().as_default(): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root = Root(optimizer=optimizer, network=network) + input_value = constant_op.constant([[3.]]) + train_op = optimizer.minimize( + network(input_value), + global_step=root.global_step) + init_op = variables.global_variables_initializer() + checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + with self.test_session(graph=ops.get_default_graph()) as session: + if checkpoint_path is None: + self.assertEqual(0, training_continuation) + session.run(init_op) + # Another alternative would be to run initializers automatically + # if no checkpoint is being loaded. This would make deferred + # loading a bit more useful with graph execution. 
+ else: + checkpointable.restore( + save_path=checkpoint_path, + root_checkpointable=root, + object_graph_proto=latest_object_graph, + session=session) + for _ in range(num_training_steps): + session.run(train_op) + latest_object_graph, _ = checkpointable.save( + file_prefix=checkpoint_prefix, + root_checkpointable=root, + session=session) + self.assertEqual((training_continuation + 1) * num_training_steps, + session.run(root.global_step)) + def _get_checkpoint_name(self, name): root = checkpointable.Checkpointable() with variable_scope.variable_scope("get_checkpoint_name"): @@ -255,7 +393,7 @@ class CheckpointNamingTests(test.TestCase): leaf.add_variable(name="v", shape=[]) named_variables, _ = checkpointable._serialize_object_graph(root) variable_name, = named_variables.keys() - self.assertEqual(r"_0/v", variable_name) + self.assertEqual(r"_1/v", variable_name) @test_util.run_in_graph_and_eager_modes() def testLocalNameValidation(self): diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 1f7beee685..0ff8746884 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -22,7 +22,7 @@ import gc import tempfile import time -from six.moves import xrange +from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf import tensorflow.contrib.eager as tfe diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 40919f2d4c..aa87b94e7b 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -65,7 +65,6 @@ import six import tensorflow as tf from tensorflow.contrib.eager.python import tfe -from tensorflow.python.eager import context try: import 
matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index d34e9ea68b..5c5c59c877 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -339,8 +339,7 @@ if __name__ == "__main__": "http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz") parser.add_argument( "--logdir", type=str, default="", help="Directory for checkpoint.") - parser.add_argument( - "--epoch", type=int, default=20, help="Number of epochs.") + parser.add_argument("--epoch", type=int, default=20, help="Number of epochs.") parser.add_argument("--batch-size", type=int, default=20, help="Batch size.") parser.add_argument( "--seq-len", type=int, default=35, help="Sequence length.") diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 19b0104c80..7b2f09cba1 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -26,7 +26,7 @@ import tempfile import time import numpy as np -from six.moves import xrange +from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf # pylint: disable=g-bad-import-order diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 81c77e41ac..3329fc6c51 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -539,7 +539,7 @@ class NetworkTest(test.TestCase): # No issue here since the name is unique within its scope. name_conflict3 = MyNetwork(name="name_conflict") net2 = MyNetwork() # name=outside_scope/my_network_2 to avoid the - # variable_scope my_network_1 below. + # variable_scope my_network_1 below. 
vs_name_conflict = MyNetwork(name="vs_name_conflict") # conflict below with variable_scope.variable_scope("intervening_scope"): with variable_scope.variable_scope(captured_scope): diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py index f72280c4ec..b2dfe48b2d 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm.py +++ b/tensorflow/contrib/factorization/python/ops/gmm.py @@ -24,17 +24,16 @@ import numpy as np from tensorflow.contrib import framework from tensorflow.contrib.factorization.python.ops import gmm_ops from tensorflow.contrib.framework.python.framework import checkpoint_utils -from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops as logging -from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops.control_flow_ops import with_dependencies from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util def _streaming_sum(scalar_tensor): @@ -70,8 +69,8 @@ class _InitializeClustersHook(session_run_hook.SessionRunHook): class GMM(estimator.Estimator): """An estimator for GMM clustering.""" SCORES = 'scores' + LOG_LIKELIHOOD = 'loss' ASSIGNMENTS = 'assignments' - ALL_SCORES = 'all_scores' def __init__(self, num_clusters, @@ -113,10 +112,7 @@ class GMM(estimator.Estimator): yield result[GMM.ASSIGNMENTS] def score(self, input_fn=None, batch_size=None, steps=None): - """Predict total sum of distances to nearest clusters. - - Note that this function is different from the corresponding one in sklearn - which returns the negative of the sum of distances. 
+ """Predict total log-likelihood. Args: input_fn: see predict. @@ -124,11 +120,11 @@ class GMM(estimator.Estimator): steps: see predict. Returns: - Total sum of distances to nearest clusters. + Total log-likelihood. """ results = self.evaluate(input_fn=input_fn, batch_size=batch_size, steps=steps) - return np.sum(results[GMM.SCORES]) + return np.log(np.sum(np.exp(results[GMM.SCORES]))) def weights(self): """Returns the cluster weights.""" @@ -158,9 +154,10 @@ class GMM(estimator.Estimator): def _model_fn(features, labels, mode, config): """Model function.""" assert labels is None, labels - (all_scores, + (loss, + scores, model_predictions, - losses, training_op, + training_op, init_op, is_initialized) = gmm_ops.gmm(self._parse_tensor_or_dict(features), self._training_initial_clusters, @@ -168,16 +165,15 @@ class GMM(estimator.Estimator): self._covariance_type, self._params) incr_step = state_ops.assign_add(training_util.get_global_step(), 1) - loss = math_ops.reduce_sum(losses) training_op = with_dependencies([training_op, incr_step], loss) training_hooks = [_InitializeClustersHook( init_op, is_initialized, config.is_chief)] predictions = { - GMM.ALL_SCORES: all_scores[0], GMM.ASSIGNMENTS: model_predictions[0][0], } eval_metric_ops = { - GMM.SCORES: _streaming_sum(loss), + GMM.SCORES: scores, + GMM.LOG_LIKELIHOOD: _streaming_sum(loss), } return model_fn_lib.ModelFnOps(mode=mode, predictions=predictions, eval_metric_ops=eval_metric_ops, diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index a61681c7f5..98d6434f47 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -21,7 +21,6 @@ from __future__ import division from __future__ import print_function import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import constant_op from 
tensorflow.python.framework import dtypes @@ -36,7 +35,6 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.embedding_ops import embedding_lookup -from tensorflow.python.summary import summary # Machine epsilon. MEPS = np.finfo(float).eps @@ -253,14 +251,16 @@ class GmmAlgorithm(object): return ret def scores(self): - """Returns the distances to each class. + """Returns the per-sample likelihood of the data. Returns: - A tuple with two Tensors. The first contains the distance to - each class. The second contains the distance to the assigned - class. + Log probabilities of each data point. """ - return (self._all_scores, self._scores) + return self._scores + + def log_likelihood_op(self): + """Returns the log-likelihood operation.""" + return self._log_likelihood_op def _define_graph(self, data): """Define graph for a single iteration. @@ -276,7 +276,8 @@ self._define_expectation_operation(shard_id) self._define_partial_maximization_operation(shard_id, shard) self._define_maximization_operation(len(data)) - self._define_distance_to_clusters(data) + self._define_loglikelihood_operation() + self._define_score_samples() def _define_full_covariance_probs(self, shard_id, shard): """Defines the full covariance probabilties per example in a class. @@ -440,50 +441,20 @@ state_ops.assign( self._covs, new_covs, validate_shape=False)) - def _define_distance_to_clusters(self, data): - """Defines the Mahalanobis distance to the assigned Gaussian.""" - # TODO(xavigonzalvo): reuse (input - mean) * cov^-1 * (input - - # mean) from log probability function.
- self._all_scores = [] - for shard in data: - all_scores = [] - shard = array_ops.expand_dims(shard, 0) - for c in xrange(self._num_classes): - if self._covariance_type == FULL_COVARIANCE: - cov = self._covs[c, :, :] - elif self._covariance_type == DIAG_COVARIANCE: - cov = array_ops.diag(self._covs[c, :]) - inverse = linalg_ops.matrix_inverse(cov + self._min_var) - inv_cov = array_ops.tile( - array_ops.expand_dims(inverse, 0), - array_ops.stack([self._num_examples, 1, 1])) - diff = array_ops.transpose(shard - self._means[c, :, :], perm=[1, 0, 2]) - m_left = math_ops.matmul(diff, inv_cov) - all_scores.append( - math_ops.sqrt( - math_ops.matmul( - m_left, array_ops.transpose( - diff, perm=[0, 2, 1])))) - self._all_scores.append( - array_ops.reshape( - array_ops.concat(all_scores, 1), - array_ops.stack([self._num_examples, self._num_classes]))) - - # Distance to the associated class. - self._all_scores = array_ops.concat(self._all_scores, 0) - assignments = array_ops.concat(self.assignments(), 0) - rows = math_ops.to_int64(math_ops.range(0, self._num_examples)) - indices = array_ops.concat( - [array_ops.expand_dims(rows, 1), array_ops.expand_dims(assignments, 1)], - 1) - self._scores = array_ops.gather_nd(self._all_scores, indices) - def _define_loglikelihood_operation(self): """Defines the total log-likelihood of current iteration.""" - self._ll_op = [] + op = [] for prior_probs in self._prior_probs: - self._ll_op.append(math_ops.reduce_sum(math_ops.log(prior_probs))) - summary.scalar('ll', math_ops.reduce_sum(self._ll_op)) + op.append(math_ops.reduce_logsumexp(prior_probs)) + self._log_likelihood_op = math_ops.reduce_logsumexp(op) + + def _define_score_samples(self): + """Defines the likelihood of each data sample.""" + op = [] + for shard_id, prior_probs in enumerate(self._prior_probs): + op.append(prior_probs + math_ops.log(self._w[shard_id])) + self._scores = array_ops.squeeze( + math_ops.reduce_logsumexp(op, axis=2, keep_dims=True), axis=0) def gmm(inp, @@ 
-511,14 +482,9 @@ def gmm(inp, Returns: Note: tuple of lists returned to be consistent with skflow A tuple consisting of: - all_scores: A matrix (or list of matrices) of dimensions (num_input, - num_clusters) where the value is the distance of an input vector and a - cluster center. assignments: A vector (or list of vectors). Each element in the vector corresponds to an input row in 'inp' and specifies the cluster id corresponding to the input. - scores: Similar to assignments but specifies the distance to the - assigned cluster instead. training_op: an op that runs an iteration of training. init_op: an op that runs the initialization. """ @@ -532,6 +498,7 @@ def gmm(inp, gmm_tool = GmmAlgorithm(inp, num_clusters, initial_means, params, covariance_type, random_seed) assignments = gmm_tool.assignments() - all_scores, scores = gmm_tool.scores() - return ([all_scores], [assignments], [scores], gmm_tool.training_ops(), + scores = gmm_tool.scores() + loss = gmm_tool.log_likelihood_op() + return (loss, scores, [assignments], gmm_tool.training_ops(), gmm_tool.init_ops(), gmm_tool.is_initialized()) diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py index c50e82db8a..888c3c238c 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py @@ -122,17 +122,23 @@ class GmmOpsTest(test.TestCase): g.seed = 5 with self.test_session() as sess: data = constant_op.constant(self.data, dtype=dtypes.float32) - _, assignments, _, training_op, init_op, _ = gmm_ops.gmm( + loss_op, scores, assignments, training_op, init_op, _ = gmm_ops.gmm( data, 'random', num_classes, random_seed=self.seed) variables.global_variables_initializer().run() sess.run(init_op) + first_loss = sess.run(loss_op) for _ in xrange(self.iterations): sess.run(training_op) assignments = sess.run(assignments) + end_loss = sess.run(loss_op) + scores = 
sess.run(scores) + self.assertEqual((self.num_examples, 1), scores.shape) accuracy = np.mean( np.asarray(self.true_assignments) == np.squeeze(assignments)) logging.info('Accuracy: %f', accuracy) + logging.info('First loss: %f, end loss: %f', first_loss, end_loss) + self.assertGreater(end_loss, first_loss) self.assertGreater(accuracy, 0.98) def testParams(self): diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py index 7717b47dae..00a4734eb6 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.factorization.python.ops import gmm as gmm_lib from tensorflow.contrib.learn.python.learn.estimators import kmeans @@ -30,12 +29,9 @@ from tensorflow.python.framework import random_seed as random_seed_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import random_ops -from tensorflow.python.platform import flags from tensorflow.python.platform import test from tensorflow.python.training import queue_runner -FLAGS = flags.FLAGS - class GMMTest(test.TestCase): @@ -64,9 +60,8 @@ class GMMTest(test.TestCase): self.batch_size = self.num_points self.true_centers = self.make_random_centers(self.num_centers, self.num_dims) - self.points, self.assignments, self.scores = self.make_random_points( + self.points, self.assignments = self.make_random_points( self.true_centers, self.num_points) - self.true_score = np.add.reduce(self.scores) # Use initial means from kmeans (just like scikit-learn does). 
clusterer = kmeans.KMeansClustering(num_clusters=self.num_centers) @@ -86,24 +81,7 @@ class GMMTest(test.TestCase): offsets = np.round( np.random.randn(num_points, num_dims).astype(np.float32) * 20) points = centers[assignments] + offsets - means = [ - np.mean( - points[assignments == center], axis=0) - for center in xrange(num_centers) - ] - covs = [ - np.cov(points[assignments == center].T) - for center in xrange(num_centers) - ] - scores = [] - for r in xrange(num_points): - scores.append( - np.sqrt( - np.dot( - np.dot(points[r, :] - means[assignments[r]], - np.linalg.inv(covs[assignments[r]])), points[r, :] - - means[assignments[r]]))) - return (points, assignments, scores) + return (points, assignments) def test_weights(self): """Tests the shape of the weights.""" @@ -136,8 +114,7 @@ class GMMTest(test.TestCase): gmm.fit(input_fn=self.input_fn(), steps=10) score2 = gmm.score(input_fn=self.input_fn(batch_size=self.num_points), steps=1) - self.assertGreater(score1, score2) - self.assertNear(self.true_score, score2, self.true_score * 0.15) + self.assertLess(score1, score2) def test_infer(self): gmm = gmm_lib.GMM(self.num_centers, @@ -149,8 +126,7 @@ class GMMTest(test.TestCase): # Make a small test set num_points = 40 - points, true_assignments, true_offsets = ( - self.make_random_points(clusters, num_points)) + points, true_assignments = self.make_random_points(clusters, num_points) assignments = [] for item in gmm.predict_assignments( @@ -159,11 +135,6 @@ class GMMTest(test.TestCase): assignments = np.ravel(assignments) self.assertAllEqual(true_assignments, assignments) - # Test score - score = gmm.score(input_fn=self.input_fn(points=points, - batch_size=num_points), steps=1) - self.assertNear(score, np.sum(true_offsets), 4.05) - def _compare_with_sklearn(self, cov_type): # sklearn version. 
iterations = 40 diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py index 8f44698da8..35974b9e21 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py @@ -27,16 +27,11 @@ import numpy as np from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 from tensorflow.python.eager import backprop -from tensorflow.python.eager import context as eager_context -from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.ops import gradients -from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py index 6f65fe771e..45962098e9 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py @@ -22,7 +22,6 @@ import numpy as np from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.framework import test_util diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 3f1ece4510..0754c3e0e3 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -25,6 +25,7 @@ import re from 
tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope from tensorflow.contrib.framework.python.ops import gen_variable_ops from tensorflow.contrib.util import loader +from tensorflow.core.protobuf import saver_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import dtypes @@ -32,9 +33,8 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import gen_state_ops -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import training_util from tensorflow.python.util.deprecation import deprecated @@ -685,7 +685,8 @@ def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False, 'Variable %s missing in checkpoint %s', var, model_path) var_list = available_vars if var_list: - saver = tf_saver.Saver(var_list, reshape=reshape_variables) + saver = tf_saver.Saver(var_list, reshape=reshape_variables, + write_version=saver_pb2.SaverDef.V1) def callback(session): saver.restore(session, model_path) return callback diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index d9b07e62f8..fdfabd07c1 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -202,10 +202,13 @@ def get_graph_def_from_url_tarball(url, filename, tar_filename=None): A GraphDef loaded from a file in the downloaded tarball. 
""" if not (tar_filename and os.path.exists(tar_filename)): + def _progress(count, block_size, total_size): - sys.stdout.write('\r>> Downloading %s %.1f%%' % ( - url, float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.write('\r>> Downloading %s %.1f%%' % + (url, + float(count * block_size) / float(total_size) * 100.0)) sys.stdout.flush() + tar_filename, _ = urllib.request.urlretrieve(url, tar_filename, _progress) with tarfile.open(tar_filename, 'r:gz') as tar: proto_str = tar.extractfile(filename).read() diff --git a/tensorflow/contrib/hvx/README.md b/tensorflow/contrib/hvx/README.md index cb3a1087de..163993a3f6 100644 --- a/tensorflow/contrib/hvx/README.md +++ b/tensorflow/contrib/hvx/README.md @@ -141,16 +141,16 @@ Configuring the installer for this system's environment... Launching installer... -An internal LaunchAnywhere application error has occured and this application cannot proceed. (LAX) +An internal LaunchAnywhere application error has occurred and this application cannot proceed. (LAX) Stack Trace: java.lang.IllegalArgumentException: Malformed \uxxxx encoding. - at java.util.Properties.loadConvert(Properties.java:574) - at java.util.Properties.load0(Properties.java:391) - at java.util.Properties.load(Properties.java:317) - at com.zerog.common.java.util.PropertiesUtil.loadProperties(Unknown Source) - at com.zerog.lax.LAX.<init>(Unknown Source) - at com.zerog.lax.LAX.main(Unknown Source) + at java.util.Properties.loadConvert(Properties.java:574) + at java.util.Properties.load0(Properties.java:391) + at java.util.Properties.load(Properties.java:317) + at com.zerog.common.java.util.PropertiesUtil.loadProperties(Unknown Source) + at com.zerog.lax.LAX.<init>(Unknown Source) + at com.zerog.lax.LAX.main(Unknown Source) ``` It can be solved by temporarily assigning the `PS1` environment variable to something simple, such as '$'. 
diff --git a/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py index bf0c97245f..3f4029e558 100644 --- a/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py @@ -18,13 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from six.moves import xrange # pylint: disable=redefined-builtin - from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms \ import single_image_random_dot_stereograms -from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest diff --git a/tensorflow/contrib/kafka/BUILD b/tensorflow/contrib/kafka/BUILD index f7593aa462..efb403462a 100644 --- a/tensorflow/contrib/kafka/BUILD +++ b/tensorflow/contrib/kafka/BUILD @@ -22,7 +22,7 @@ tf_kernel_library( "//tensorflow/core/kernels:bounds_check_lib", "//tensorflow/core/kernels:dataset", "//third_party/eigen3", - "@kafka//:kafka", + "@kafka", ], ) @@ -88,6 +88,7 @@ tf_py_test( ], tags = [ "manual", + "notap", ], ) diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py index 94cf6b5ace..621911876f 100644 --- a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py @@ -18,21 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -import os - from tensorflow.contrib.kafka.python.ops import kafka_dataset_ops from tensorflow.python.data.ops import iterator_ops -from 
tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops -from tensorflow.python.ops import io_ops from tensorflow.python.platform import test -from tensorflow.python.util import compat + class KafkaDatasetTest(test.TestCase): @@ -64,52 +56,58 @@ class KafkaDatasetTest(test.TestCase): with self.test_session() as sess: # Basic test: read from topic 0. - sess.run( - init_op, feed_dict={topics: ["test:0:0:4"], - num_epochs: 1}) + sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) for i in range(5): - self.assertEqual("D"+str(i), sess.run(get_next)) + self.assertEqual("D" + str(i), sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Basic test: read from topic 1. - sess.run( - init_op, feed_dict={topics: ["test:0:5:-1"], - num_epochs: 1}) + sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1}) for i in range(5): - self.assertEqual("D"+str(i + 5), sess.run(get_next)) + self.assertEqual("D" + str(i + 5), sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Basic test: read from both topics. - sess.run(init_op, feed_dict={topics: ["test:0:0:4", "test:0:5:-1"], - num_epochs: 1}) + sess.run( + init_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 1 + }) for j in range(2): for i in range(5): - self.assertEqual("D"+str(i + j * 5), sess.run(get_next)) + self.assertEqual("D" + str(i + j * 5), sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Test repeated iteration through both files. 
- sess.run(init_op, feed_dict={topics: ["test:0:0:4", "test:0:5:-1"], - num_epochs: 10}) + sess.run( + init_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 10 + }) for _ in range(10): for j in range(2): for i in range(5): - self.assertEqual("D"+str(i + j * 5), sess.run(get_next)) + self.assertEqual("D" + str(i + j * 5), sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Test batched and repeated iteration through both files. sess.run( init_batch_op, - feed_dict={topics: ["test:0:0:4", "test:0:5:-1"], - num_epochs: 10, - batch_size: 5}) + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 10, + batch_size: 5 + }) for _ in range(10): - self.assertAllEqual(["D"+str(i) for i in range(5)], + self.assertAllEqual(["D" + str(i) for i in range(5)], sess.run(get_next)) - self.assertAllEqual(["D"+str(i + 5) for i in range(5)], + self.assertAllEqual(["D" + str(i + 5) for i in range(5)], sess.run(get_next)) diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh index 7997c12731..adf027b8e7 100644 --- a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh @@ -1,4 +1,18 @@ #!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== set -e set -o pipefail diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index e561f595a4..8e51d27a34 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -18,20 +18,22 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.kafka.python.ops import gen_kafka_ops -from tensorflow.contrib.util import loader from tensorflow.python.data.ops.readers import Dataset -from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.platform import resource_loader + class KafkaDataset(Dataset): """A Kafka Dataset that consumes the message. """ - def __init__( - self, topics, servers="localhost", group="", eof=False, timeout=1000): + def __init__(self, + topics, + servers="localhost", + group="", + eof=False, + timeout=1000): """Create a KafkaReader. 
Args: @@ -51,14 +53,13 @@ class KafkaDataset(Dataset): servers, dtype=dtypes.string, name="servers") self._group = ops.convert_to_tensor( group, dtype=dtypes.string, name="group") - self._eof = ops.convert_to_tensor( - eof, dtype=dtypes.bool, name="eof") + self._eof = ops.convert_to_tensor(eof, dtype=dtypes.bool, name="eof") self._timeout = ops.convert_to_tensor( timeout, dtype=dtypes.int64, name="timeout") def _as_variant_tensor(self): - return gen_kafka_ops.kafka_dataset( - self._topics, self._servers, self._group, self._eof, self._timeout) + return gen_kafka_ops.kafka_dataset(self._topics, self._servers, self._group, + self._eof, self._timeout) @property def output_classes(self): diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index fb7b2e315e..1c3af19a6c 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -59,13 +59,12 @@ __all__ = [ 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d', 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', - 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', - 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', 'layer_norm', - 'linear', 'pool', 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', - 'relu6', 'repeat', 'scale_gradient', 'separable_conv2d', - 'separable_convolution2d', 'softmax', 'spatial_softmax', 'stack', - 'unit_norm', 'legacy_fully_connected', 'legacy_linear', 'legacy_relu', - 'maxout' + 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', 'dropout', + 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', 'layer_norm', 'linear', + 'pool', 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', + 'repeat', 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', + 'softmax', 'spatial_softmax', 'stack', 'unit_norm', + 
'legacy_fully_connected', 'legacy_linear', 'legacy_relu', 'maxout' ] DATA_FORMAT_NCHW = 'NCHW' @@ -1415,12 +1414,11 @@ def dense_to_sparse(tensor, eos_token=0, outputs_collections=None, scope=None): outputs_collections: Collection to add the outputs. scope: Optional scope for name_scope. """ - with variable_scope.variable_scope( - scope, 'dense_to_sparse', [tensor]) as sc: + with variable_scope.variable_scope(scope, 'dense_to_sparse', [tensor]) as sc: tensor = ops.convert_to_tensor(tensor) indices = array_ops.where( - math_ops.not_equal( - tensor, constant_op.constant(eos_token, tensor.dtype))) + math_ops.not_equal(tensor, constant_op.constant(eos_token, + tensor.dtype))) values = array_ops.gather_nd(tensor, indices) shape = array_ops.shape(tensor, out_type=dtypes.int64) outputs = sparse_tensor.SparseTensor(indices, values, shape) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 8945690db8..972ff10bf9 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1308,12 +1308,13 @@ class DenseToSparseTest(test.TestCase): expected_constant = np.reshape(np.arange(24, dtype=np.int64), (3, 4, 2)) tensor = constant_op.constant(expected_constant) sparse = _layers.dense_to_sparse(tensor) - dense = sparse_ops.sparse_to_dense( - sparse.indices, sparse.dense_shape, sparse.values) + dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape, + sparse.values) with self.test_session() as sess: constant = sess.run(dense) self.assertAllEqual(expected_constant, constant) + class DropoutTest(test.TestCase): def testCreateDropout(self): diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py index 18bf16e246..ca720ae5ed 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/base.py +++ b/tensorflow/contrib/learn/python/learn/datasets/base.py @@ 
-23,13 +23,11 @@ import csv import os from os import path import random -import tempfile import time import numpy as np from six.moves import urllib -from tensorflow.contrib.framework import deprecated from tensorflow.python.platform import gfile Dataset = collections.namedtuple('Dataset', ['data', 'target']) diff --git a/tensorflow/contrib/learn/python/learn/ops/ops_test.py b/tensorflow/contrib/learn/python/learn/ops/ops_test.py index d0b9eb8abc..80d4923db3 100644 --- a/tensorflow/contrib/learn/python/learn/ops/ops_test.py +++ b/tensorflow/contrib/learn/python/learn/ops/ops_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.layers import conv2d from tensorflow.contrib.learn.python.learn import ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile index 4ae6fb6b94..c7d3b1c966 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/Podfile +++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile @@ -2,4 +2,4 @@ platform :ios, '8.0' inhibit_all_warnings! 
target 'tflite_camera_example' - pod 'TensorFlow-experimental' + pod 'TensorFlowLite' diff --git a/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj index c98183276b..b0236e9c60 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj +++ b/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj @@ -16,7 +16,6 @@ 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 1CDB2D4D1ED3AA35007929E9 /* Info.plist */; }; 54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */; }; AC1F82661FBA3CBD0052BA77 /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = AC1F82641FBA3CBD0052BA77 /* labels.txt */; }; - AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */; }; ACA1A4CA1FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */; }; /* End PBXBuildFile section */ @@ -38,7 +37,6 @@ 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example.debug.xcconfig"; sourceTree = "<group>"; }; 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.release.xcconfig"; path = "Pods/Target Support 
Files/Pods-tflite_camera_example/Pods-tflite_camera_example.release.xcconfig"; sourceTree = "<group>"; }; AC1F82641FBA3CBD0052BA77 /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = "<group>"; }; - AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = "<group>"; }; ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_quant_v1_224.tflite; sourceTree = "<group>"; }; /* End PBXFileReference section */ @@ -47,7 +45,6 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( - AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */, 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */, 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */, 54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */, @@ -60,7 +57,6 @@ 24D7686C331131624F4454A0 /* Frameworks */ = { isa = PBXGroup; children = ( - AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */, 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */, 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */, 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, @@ -336,7 +332,6 @@ ../../../downloads/, ); IPHONEOS_DEPLOYMENT_TARGET = 8.0; - LIBRARY_SEARCH_PATHS = ../../../gen/lib/; MTL_ENABLE_DEBUG_INFO = YES; ONLY_ACTIVE_ARCH = YES; SDKROOT = iphoneos; @@ -384,7 +379,6 @@ ../../../downloads/, ); IPHONEOS_DEPLOYMENT_TARGET = 8.0; - LIBRARY_SEARCH_PATHS = ../../../gen/lib/; MTL_ENABLE_DEBUG_INFO = NO; SDKROOT = iphoneos; TARGETED_DEVICE_FAMILY = "1,2"; diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD index d216cdf69b..959347b549 100644 --- a/tensorflow/contrib/lite/examples/label_image/BUILD +++ 
b/tensorflow/contrib/lite/examples/label_image/BUILD @@ -43,8 +43,13 @@ cc_library( "label_image.h", ], deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/schema:schema_fbs", ], ) diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h index 471fda2ba4..97343dde6b 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H -#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_ +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_ #include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h" #include "tensorflow/contrib/lite/examples/label_image/label_image.h" @@ -31,8 +31,8 @@ void resize(T* out, uint8_t* in, int image_height, int image_width, int wanted_channels, Settings* s); // explicit instantiation -template void resize<uint8_t>(uint8_t*, unsigned char*, int, int, int, int, - int, int, Settings*); +template void resize<uint8_t>(uint8_t*, unsigned char*, int, int, int, int, int, + int, Settings*); template void resize<float>(float*, unsigned char*, int, int, int, int, int, int, Settings*); diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h index 33ea695dda..f0d81cf7a4 
100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -13,8 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H -#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/version.h" #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/interpreter.h" @@ -31,7 +37,6 @@ template <class T> void resize(T* out, uint8_t* in, int image_height, int image_width, int image_channels, int wanted_height, int wanted_width, int wanted_channels, Settings* s) { - int number_of_pixels = image_height * image_width * image_channels; std::unique_ptr<Interpreter> interpreter(new Interpreter); @@ -45,7 +50,7 @@ void resize(T* out, uint8_t* in, int image_height, int image_width, interpreter->SetInputs({0, 1}); interpreter->SetOutputs({2}); - // set paramters of tensors + // set parameters of tensors TfLiteQuantizationParams quant; interpreter->SetTensorParametersReadWrite( 0, kTfLiteFloat32, "input", @@ -92,4 +97,4 @@ void resize(T* out, uint8_t* in, int image_height, int image_width, } // namespace label_image } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H +#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ diff --git 
a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index a78900122e..a91467d345 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -151,14 +151,14 @@ void RunInference(Settings* s) { switch (interpreter->tensor(input)->type) { case kTfLiteFloat32: s->input_floating = true; - resize<float>(interpreter->typed_tensor<float>(input), in, - image_height, image_width, image_channels, - wanted_height, wanted_width, wanted_channels, s); + resize<float>(interpreter->typed_tensor<float>(input), in, image_height, + image_width, image_channels, wanted_height, wanted_width, + wanted_channels, s); break; case kTfLiteUInt8: resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in, - image_height, image_width, image_channels, - wanted_height, wanted_width, wanted_channels, s); + image_height, image_width, image_channels, wanted_height, + wanted_width, wanted_channels, s); break; default: LOG(FATAL) << "cannot handle input type " @@ -188,9 +188,8 @@ void RunInference(Settings* s) { int output = interpreter->outputs()[0]; switch (interpreter->tensor(output)->type) { case kTfLiteFloat32: - get_top_n<float>(interpreter->typed_output_tensor<float>(0), - output_size, num_results, threshold, &top_results, - true); + get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size, + num_results, threshold, &top_results, true); break; case kTfLiteUInt8: get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0), diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 9dd60abc86..5aa0cbafd6 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -303,7 +303,6 @@ TfLiteStatus Interpreter::Invoke() { TfLiteStatus status = kTfLiteOk; if (nnapi_delegate_) { - TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); if 
(next_execution_plan_index_to_prepare_ == execution_plan_.size()) { TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); return kTfLiteOk; diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/ios_makefile.inc index 26cfe6c3e2..fc6594c3a0 100644 --- a/tensorflow/contrib/lite/ios_makefile.inc +++ b/tensorflow/contrib/lite/ios_makefile.inc @@ -22,6 +22,7 @@ ifeq ($(TARGET), IOS) IOS_ARCH := x86_64 CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -DTFLITE_USE_APPLE_ACCELERATE_FOR_CONV \ -fembed-bitcode \ -Wno-c++11-narrowing \ -mno-thumb \ @@ -42,6 +43,7 @@ ifeq ($(TARGET), IOS) -O3 LDFLAGS := -fembed-bitcode \ -miphoneos-version-min=${MIN_SDK_VERSION} \ + -framework Accelerate \ -arch $(IOS_ARCH) OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/ LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/ diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 1fba3cbbce..66d2c04bba 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -463,9 +463,7 @@ TfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT() { } TfLiteRegistration* Register_CONV_2D() { -// TODO(ycling): Define a compilation flag and use CBLAS kernel when a -// fast CBLAS implementatino is available. 
-#ifdef TFLITE_USE_CBLAS_CONVOLUTION_KERNEL +#ifdef TFLITE_USE_APPLE_ACCELERATE_FOR_CONV return Register_CONVOLUTION_CBLAS_OPT(); #else return Register_CONVOLUTION_MULTITHREADED_OPT(); diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index adedd58ff4..a6ccc99a51 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -291,7 +291,7 @@ cc_library( "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite/kernels:activation_functor", "@arm_neon_2_x86_sse", - "@gemmlowp//:gemmlowp", + "@gemmlowp", ], ) @@ -325,7 +325,7 @@ cc_library( "//tensorflow/contrib/lite/kernels:activation_functor", "//tensorflow/contrib/lite:builtin_op_data", "@arm_neon_2_x86_sse", - "@gemmlowp//:gemmlowp", + "@gemmlowp", ] + select({ ":arm": [ ":neon_tensor_utils", diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index fdeacedace..18601df22c 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -102,6 +102,17 @@ inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( quantized_multiplier); } +inline int32 MultiplyByQuantizedMultiplier(int32 x, int32 quantized_multiplier, + int shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + int left_shift = shift > 0 ? shift : 0; + int right_shift = shift > 0 ? 
0 : -shift; + return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( + x * (1 << left_shift), quantized_multiplier), + right_shift); +} + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h index fcb9fac671..4a90e7e640 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h @@ -19,9 +19,12 @@ limitations under the License. // The Conv implementation based on CBLAS interface. This is only used on iOS // for now, utilizing Apple's Accelerate framework. -// TODO(ycling): Update the BUILD file and integrate with Apple Accelerate -// Famework when it's available. +#if TFLITE_USE_APPLE_ACCELERATE_FOR_CONV +#include <Accelerate/Accelerate.h> +#else #include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h" +#endif + #include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h index e0eca2e736..3a53d3ab07 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h @@ -34,7 +34,7 @@ inline bool TestCPUFeatureNeon() { #endif // __aarch64__ } -#elif defined USE_NEON || defined __ARM_NEON +#elif defined USE_NEON || defined __ARM_NEON inline bool TestCPUFeatureNeon() { return true; } diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc index ea8502ae33..780401e052 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ 
b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" #ifdef USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index d5b0f45fd8..cd52385f41 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -2081,6 +2081,166 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, output_state_map.tanh(); } +// Quantized LSTM cell. Currently just a copy of the reference impl in +// reference_ops.h. See the big function comment there, not replicating it +// here. +template <int StateIntegerBits> +void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, + const uint8* prev_activ_data_uint8, + const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8, + const Dims<4>& weights_dims, const int32* bias_data_int32, + const Dims<4>& bias_dims, const int16* prev_state_data_int16, + const Dims<4>& prev_state_dims, int16* output_state_data_int16, + const Dims<4>& output_state_dims, uint8* output_activ_data_uint8, + const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8, + const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16, + const Dims<4>& activ_temp_dims, int32 weights_zero_point, + int32 accum_multiplier, int accum_shift) { + gemmlowp::ScopedProfilingLabel label( + "LstmCell/quantized (8bit external, 16bit internal)"); + // Gather dimensions information, and perform consistency checks. 
+ const int batches = + MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, + output_state_dims, 3, output_activ_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, + output_state_dims, 2, output_activ_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, + output_state_dims, 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + const int fc_batches = ArraySize(activ_temp_dims, 1) * + ArraySize(activ_temp_dims, 2) * + ArraySize(activ_temp_dims, 3); + const int fc_output_depth = + MatchingArraySize(weights_dims, 1, activ_temp_dims, 0); + const int fc_accum_depth = ArraySize(weights_dims, 0); + TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth); + + // Depth-concatenate prev_activ and input data together. + uint8 const* concat_input_arrays_data[2] = {input_data_uint8, + prev_activ_data_uint8}; + Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims}; + Concatenation<FusedActivationFunctionType::kNone, uint8>( + 0, concat_input_arrays_data, concat_input_arrays_dims, 2, + concat_temp_data_uint8, concat_temp_dims); + + // Implementation of the fully connected node inside the LSTM cell. 
+ // The operands are 8-bit integers, the accumulators are internally 32bit + // integers, and the output is 16-bit fixed-point with 3 integer bits so + // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that + // is explained in the function comment above. + for (int b = 0; b < fc_batches; ++b) { + for (int out_c = 0; out_c < fc_output_depth; ++out_c) { + // Internal accumulation. + // Initialize accumulator with the bias-value. + int32 accum = bias_data_int32[out_c]; + // Accumulation loop. + for (int d = 0; d < fc_accum_depth; ++d) { + int16 input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128; + int16 weights_val = + weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point; + accum += input_val * weights_val; + } + // Down-scale the final int32 accumulator to the scale used by our + // (16-bit, using 3 integer bits) fixed-point format. The quantized + // multiplier and shift here have been pre-computed offline + // (e.g. by toco). + // Note that the implicit assumption here, that this multiplier is smaller + // than one, is equivalent to the assumption that the fully-connected + // weights min-max is enclosed within [-4, 4] (it may be narrower). + // If that eventually fails, offline tools (e.g. toco) will fail early + // and that will be easy to support as needed. For now, assuming that + // this multiplier is less than one allows us to use a simpler, more + // accurate implementation. + accum = + MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift); + // Saturate, cast to int16, and store to the temporary activations array. + accum = std::max(-32768, std::min(32767, accum)); + activ_temp_data_int16[out_c + fc_output_depth * b] = accum; + } + } + + // Rest of the LSTM cell: tanh and logistic math functions, and some adds + // and muls, all done in 16-bit fixed-point. 
+ const int outer_size = batches * width * height; + for (int b = 0; b < outer_size; ++b) { + for (int c = 0; c < output_depth; ++c) { + // Define the fixed-point data types that we will use here. All use + // int16 as the underlying integer type i.e. all are 16-bit fixed-point. + // They only differ by the number of integral vs. fractional bits, + // determining the range of values that they can represent. + // + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint<std::int16_t, 0>; + // F3 uses 3 integer bits, range [-8, 8]. + // This is the range of the previous fully-connected node's output, + // which is our input here. + using F3 = gemmlowp::FixedPoint<std::int16_t, 3>; + // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits, + // 2^StateIntegerBits]. It's used to represent the internal state, whose + // number of integer bits is currently dictated by the model. See comment + // on the StateIntegerBits template parameter above. + using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>; + // Implementation of input gate, using fixed-point logistic function. + F3 input_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]); + F0 input_gate_output = gemmlowp::logistic(input_gate_input); + // Implementation of input modulation gate, using fixed-point tanh + // function. + F3 input_modulation_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]); + F0 input_modulation_gate_output = + gemmlowp::tanh(input_modulation_gate_input); + // Implementation of forget gate, using fixed-point logistic function. 
+ F3 forget_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]); + F0 forget_gate_output = gemmlowp::logistic(forget_gate_input); + // Implementation of output gate, using fixed-point logistic function. + F3 output_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]); + F0 output_gate_output = gemmlowp::logistic(output_gate_input); + // Implementation of internal multiplication nodes, still in fixed-point. + F0 input_times_input_modulation = + input_gate_output * input_modulation_gate_output; + FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]); + FS prev_state_times_forget_state = forget_gate_output * prev_state; + // Implementation of internal addition node, saturating. + FS new_state = gemmlowp::SaturatingAdd( + gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation), + prev_state_times_forget_state); + // Implementation of last internal tanh node, still in fixed-point. + F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state); + // Store the new internal state back to memory, as 16-bit integers. + output_state_data_int16[b * output_depth + c] = new_state.raw(); + // Down-scale the output activations to 8-bit integers, saturating, + // and store back to memory. 
+ int16 rescaled_output_activ = + gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); + int16 clamped_output_activ = + std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ)); + output_activ_data_uint8[b * output_depth + c] = + 128 + clamped_output_activ; + } + } +} + template <FusedActivationFunctionType Ac, typename Scalar> void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, int outputs_count, Scalar* const* output_data, @@ -2942,51 +3102,152 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, uint8* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int height = MatchingArraySize(input_dims, 2, output_dims, 2); - const int width = MatchingArraySize(input_dims, 1, output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)]; - const int32 input_val_centered = - static_cast<int32>(input_val_u8) - input_zero_point; - uint8 output_val; - if (input_val_centered <= -input_range_radius) { - output_val = 0; - } else if (input_val_centered >= input_range_radius) { - output_val = 255; - } else { - const int32 input_val_rescaled = - MultiplyByQuantizedMultiplierGreaterThanOne( - input_val_centered, input_multiplier, input_left_shift); - using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>; - using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>; - const FixedPoint4 input_val_f4 = - FixedPoint4::FromRaw(input_val_rescaled); - const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4); - - using gemmlowp::RoundingDivideByPOT; - int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24); - // 
TODO(mjmatthews): properly wire through this zero offset - output_val_s32 += 127; - if (output_val_s32 == -1) { - // May underflow since we cannot properly represent -1.0f - output_val_s32 = 0; - } - TFLITE_DCHECK_GE(output_val_s32, 0); - TFLITE_DCHECK_LE(output_val_s32, 255); - output_val = static_cast<uint8>(output_val_s32); - } - output_data[Offset(output_dims, c, x, y, b)] = output_val; - } + // Note that this is almost the exact same code as in Logistic(). + gemmlowp::ScopedProfilingLabel label("Tanh"); + /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* height */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* width */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0); + const int size = RequiredBufferSizeForDims(input_dims); + + int c = 0; + int32_t output_zero_point = 128; +#ifdef USE_NEON + // Handle 16 values at a time + for (; c <= size - 16; c += 16) { + // Read input uint8 values, cast to int16 and subtract input_zero_point + uint8x16_t input_val_u8 = vld1q_u8(input_data + c); + int16x8_t input_val_centered_0 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + int16x8_t input_val_centered_1 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + + // Prepare the bit masks that we will use at the end to implement the logic + // that was expressed in the scalar code with branching: + // if (input_val_centered < -input_range_radius) { + // output_val = 0; + // } else if (input_val_centered > input_range_radius) { + // output_val = 255; + // } else { + // ... 
+ uint16x8_t mask_rightclamp_0 = + vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_rightclamp_1 = + vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_leftclamp_0 = + vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius)); + uint16x8_t mask_leftclamp_1 = + vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius)); + uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8), + vshrn_n_u16(mask_rightclamp_1, 8)); + uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8), + vshrn_n_u16(mask_leftclamp_1, 8)); + + // This performs what is expressed in the scalar code as + // const int32 input_val_rescaled = + // MultiplyByQuantizedMultiplierGreaterThanOne( + // input_val_centered, input_multiplier, input_left_shift); + int32x4_t input_val_rescaled_0 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_1 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_2 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_3 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + input_val_rescaled_0 = + vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier); + input_val_rescaled_1 = + vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier); + input_val_rescaled_2 = + vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier); + input_val_rescaled_3 = + vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier); + + // Invoke gemmlowp::tanh on FixedPoint wrapping int32x4_t + using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>; + using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>; + const FixedPoint4 input_val_f4_0 = + FixedPoint4::FromRaw(input_val_rescaled_0); + const FixedPoint4 input_val_f4_1 = + 
FixedPoint4::FromRaw(input_val_rescaled_1); + const FixedPoint4 input_val_f4_2 = + FixedPoint4::FromRaw(input_val_rescaled_2); + const FixedPoint4 input_val_f4_3 = + FixedPoint4::FromRaw(input_val_rescaled_3); + const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0); + const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1); + const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2); + const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3); + + // Divide by 2^24 as in the scalar code + using gemmlowp::RoundingDivideByPOT; + int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 24); + int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 24); + int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 24); + int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 24); + + // Add the output zero point + int32x4_t output_zero_point_s32 = vdupq_n_s32(output_zero_point); + output_val_s32_0 = vaddq_s32(output_val_s32_0, output_zero_point_s32); + output_val_s32_1 = vaddq_s32(output_val_s32_1, output_zero_point_s32); + output_val_s32_2 = vaddq_s32(output_val_s32_2, output_zero_point_s32); + output_val_s32_3 = vaddq_s32(output_val_s32_3, output_zero_point_s32); + + // Cast output values to uint8, saturating + int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0), + vqmovn_s32(output_val_s32_1)); + int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2), + vqmovn_s32(output_val_s32_3)); + uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0), + vqmovun_s16(output_val_s16_1)); + + // Perform the bit-masking with the bit masks computed at the beginning, + // see the comment there. 
+ output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp); + output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp); + + // Store back to memory + vst1q_u8(output_data + c, output_val_u8); + } +#endif + // Leftover loop: handle one value at a time with scalar code. + for (; c < size; ++c) { + const uint8 input_val_u8 = input_data[c]; + const int32 input_val_centered = + static_cast<int32>(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered < -input_range_radius) { + output_val = 0; + } else if (input_val_centered > input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>; + using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>; + const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4); + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24); + output_val_s32 += output_zero_point; + if (output_val_s32 == 256) { + output_val_s32 = 255; } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast<uint8>(output_val_s32); } + output_data[c] = output_val; } } - inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, int32 zero_point, double scale, float* output_data, const Dims<4>& output_dims) { diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index 98f2e365c5..18be6777a5 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -22,27 +22,20 @@ limitations under the License. 
namespace tflite { -void QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int* right_shift) { - TFLITE_CHECK(double_multiplier >= 0.); - TFLITE_CHECK(double_multiplier < 1.); +void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, + int* shift) { if (double_multiplier == 0.) { *quantized_multiplier = 0; - *right_shift = 0; + *shift = 0; return; } - TFLITE_CHECK(double_multiplier > 0.); - const double q = std::frexp(double_multiplier, right_shift); - *right_shift *= -1; - + const double q = std::frexp(double_multiplier, shift); auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31))); TFLITE_CHECK(q_fixed <= (1ll << 31)); if (q_fixed == (1ll << 31)) { q_fixed /= 2; - --*right_shift; + ++*shift; } - TFLITE_CHECK_GE(*right_shift, 0); TFLITE_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max()); *quantized_multiplier = static_cast<int32_t>(q_fixed); } @@ -50,17 +43,20 @@ void QuantizeMultiplierSmallerThanOne(double double_multiplier, void QuantizeMultiplierGreaterThanOne(double double_multiplier, int32_t* quantized_multiplier, int* left_shift) { - TFLITE_CHECK(double_multiplier > 1.); - const double q = std::frexp(double_multiplier, left_shift); - auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31))); - TFLITE_CHECK(q_fixed <= (1ll << 31)); - if (q_fixed == (1ll << 31)) { - q_fixed /= 2; - ++*left_shift; - } + TFLITE_CHECK_GT(double_multiplier, 1.); + QuantizeMultiplier(double_multiplier, quantized_multiplier, left_shift); TFLITE_CHECK_GE(*left_shift, 0); - TFLITE_CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max()); - *quantized_multiplier = static_cast<int32_t>(q_fixed); +} + +void QuantizeMultiplierSmallerThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* right_shift) { + TFLITE_CHECK_LT(double_multiplier, 1.); + TFLITE_CHECK_GT(double_multiplier, 0.); + int shift; + QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift); + 
TFLITE_CHECK_LE(shift, 0); + *right_shift = -shift; } void PreprocessSoftmaxScaling(double beta, double input_scale, diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index efb7191c8d..ba06bc0975 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -20,7 +20,8 @@ limitations under the License. namespace tflite { // Decompose a double multiplier into a Q0.31 int32 representation of its -// significand, and shift representation of its exponent. +// significand, and shift representation of NEGATIVE its exponent --- +// this is intended as a RIGHT-shift. // // Restricted to the case where the multiplier < 1 (and non-negative). void QuantizeMultiplierSmallerThanOne(double double_multiplier, @@ -35,6 +36,16 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier, int32_t* quantized_multiplier, int* left_shift); +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Handles an arbitrary positive multiplier. The 'shift' output-value is +// basically the 'floating-point exponent' of the multiplier: +// Negative for a right-shift (when the multiplier is <1), positive for a +// left-shift (when the multiplier is >1) +void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, + int* shift); + // This first creates a multiplier in a double equivalent of // Q(input_integer_bits).(31-input_integer_bits) representation, with extra // precision in the double's fractional bits. 
It then splits the result into diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc index d6f306e2cb..19b1b408ec 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -31,7 +31,7 @@ TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { }; EXPECT_DEATH(quantize(-0.1), ""); - EXPECT_THAT(quantize(0.0), Pair(0, 0)); + EXPECT_DEATH(quantize(0.0), ""); EXPECT_THAT(quantize(0.25), Pair(1073741824, 1)); // Around 0.5 we can see the change in exponent and how we try hard to diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 40e5c48a4c..f18543f4e4 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -1358,6 +1358,238 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, } } +// Quantized LSTM cell implementation. +// The quantization of the input, output arrays is as follows: +// - The input activations are quantized as uint8 on the interval +// [-1, 127/128]. +// The rationale for that is that that is the natural interval for output +// activations (see next point) and these need to be concatenated together. +// We could accommodate different ranges by re-scaling, but we empirically +// found that setting the input activations range to be [-1, 127/128] in the +// first place, removing the need for re-scaling, greatly improves accuracy. +// - The output activations are quantized as uint8 on the interval +// [-1, 127/128]. +// The rationale for that is that the definition of a LSTM cell makes them +// intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128] +// makes for simpler, more accurate fixed-point arithmetic. 
+// - The output-at-previous-timestep state array is obviously quantized as +// the output activations. +// - The internal LSTM memory (not the output-at-previous-timestep, the other +// internal state array) is int16-quantized and may use any power-of-two, +// symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call +// StateIntegerBits below, see the below discussion of that template +// parameter ("The StateIntegerBits template parameter"). +// - The output of the internal fully-connected node is int16-quantized +// on the interval [-8, 8 * 32767/32768], the rationale for which is +// explained just below ("Why [-8, 8] for fully-connected output?"). +// +// +// === The StateIntegerBits template parameter === +// +// The StateIntegerBits template parameter controls the fixed-point format used +// to represent the internal memory of the LSTM cell (not the +// output-at-previous-timestep, the other internal state array). It's currently +// a template parameter so that the model can control that. The most typical +// value for StateIntegerBits is 4. Other plausible values are anywhere between +// 3 and 5. We might eventually standardize on a single supported value, e.g. 4, +// and drop that template parameter. The reason why it can't be a runtime +// parameter is that this controls the fixed-point format used, i.e. we need to +// generate actually different code based on it. In particular, we generate code +// for a fixed-point tanh() implementation for that format, which internally +// uses a fixed-point exp() implementation, which internally uses a +// barrel-shifter with a number of steps that depends on StateIntegerBits. +// Another consequence of that is that a higher value of StateIntegerBits +// results in a more expensive implementation (more barrel shifter steps +// needed). +// +// +// === Why [-8, 8] for fully-connected output? 
=== +// +// This array is only fed to Logistic and Tanh functions, for which +// the quantized implementation will want to use fixed-point arithmetic, +// requiring a power-of-two representation interval. Thus, we should right +// away quantize this array to a power-of-two interval; otherwise, +// implementation will need to rescale that, losing any benefit that a tighter +// representation interval might otherwise yield, while introducing some +// numerical error and computational overhead. +// +// Now, Logistic and Tanh +// are nearly constant (nearly equal to their horizontal asymptotes) +// outside of a small bounded interval around 0: +// +// Logistic(4) = 1 - 1.8e-2 Tanh(4) = 1 - 6.7e-4 +// Logistic(8) = 1 - 3.4e-4 Tanh(8) = 1 - 2.3e-7 +// Logistic(16) = 1 - 1.1e-7 Tanh(16) = 1 - 2.5e-14 +// +// From this, we see that clamping to [-4, 4] would be too inaccurate +// (the error of 1.8e-2 on Logistic would be felt even in 8bit precision) +// while clamping to [-16, 16] would make no difference even in float32. +// However, for a fixed-point implementation in 16-bit integers, using 5 +// integer bits to represent the [-16, 16] range would leave only 11 +// fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive +// representable values. Notice that that is higher than the +// worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic. +// Using [-8, 8] thus seems like the better compromise overall, enjoying +// an increment of 2.4e-4 between representable values and a worst-case +// clamping error of 3.4e-4, both better than the increment of 4.9e-4 with +// [-16, 16]. +// +// Moreover, all other things being equal, it is nice to choose the narrower +// representation range, as that makes the implementation of fixed-point +// math functions a little cheaper (each integer bit requires an additional +// barrel-shifter step in the implementation of exp(-x)). That is further +// reason to prefer [-8, 8] over [-16, 16].
The choice of [-16, 16] would make +// sense for 32-bit float or 32-bit fixed-point quantization, but we are +// aiming for 16-bit fixed-point quantization of these internal nodes here. +// +template <int StateIntegerBits> +void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, + const uint8* prev_activ_data_uint8, + const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8, + const Dims<4>& weights_dims, const int32* bias_data_int32, + const Dims<4>& bias_dims, const int16* prev_state_data_int16, + const Dims<4>& prev_state_dims, int16* output_state_data_int16, + const Dims<4>& output_state_dims, uint8* output_activ_data_uint8, + const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8, + const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16, + const Dims<4>& activ_temp_dims, int32 weights_zero_point, + int32 accum_multiplier, int accum_shift) { + // Gather dimensions information, and perform consistency checks. + const int batches = + MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, + output_state_dims, 3, output_activ_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, + output_state_dims, 2, output_activ_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, + output_state_dims, 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + 
MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + const int fc_batches = ArraySize(activ_temp_dims, 1) * + ArraySize(activ_temp_dims, 2) * + ArraySize(activ_temp_dims, 3); + const int fc_output_depth = + MatchingArraySize(weights_dims, 1, activ_temp_dims, 0); + const int fc_accum_depth = ArraySize(weights_dims, 0); + TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth); + + // Depth-concatenate prev_activ and input data together. + uint8 const* concat_input_arrays_data[2] = {input_data_uint8, + prev_activ_data_uint8}; + Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims}; + Concatenation<FusedActivationFunctionType::kNone, uint8>( + 0, concat_input_arrays_data, concat_input_arrays_dims, 2, + concat_temp_data_uint8, concat_temp_dims); + + // Implementation of the fully connected node inside the LSTM cell. + // The operands are 8-bit integers, the accumulators are internally 32bit + // integers, and the output is 16-bit fixed-point with 3 integer bits so + // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that + // is explained in the function comment above. + for (int b = 0; b < fc_batches; ++b) { + for (int out_c = 0; out_c < fc_output_depth; ++out_c) { + // Internal accumulation. + // Initialize accumulator with the bias-value. + int32 accum = bias_data_int32[out_c]; + // Accumulation loop. + for (int d = 0; d < fc_accum_depth; ++d) { + int16 input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128; + int16 weights_val = + weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point; + accum += input_val * weights_val; + } + // Down-scale the final int32 accumulator to the scale used by our + // (16-bit, using 3 integer bits) fixed-point format. The quantized + // multiplier and shift here have been pre-computed offline + // (e.g. by toco). 
+ accum = + MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift); + // Saturate, cast to int16, and store to the temporary activations array. + accum = std::max(-32768, std::min(32767, accum)); + activ_temp_data_int16[out_c + fc_output_depth * b] = accum; + } + } + + // Rest of the LSTM cell: tanh and logistic math functions, and some adds + // and muls, all done in 16-bit fixed-point. + const int outer_size = batches * width * height; + for (int b = 0; b < outer_size; ++b) { + for (int c = 0; c < output_depth; ++c) { + // Define the fixed-point data types that we will use here. All use + // int16 as the underlying integer type i.e. all are 16-bit fixed-point. + // They only differ by the number of integral vs. fractional bits, + // determining the range of values that they can represent. + // + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint<std::int16_t, 0>; + // F3 uses 3 integer bits, range [-8, 8]. + // This is the range of the previous fully-connected node's output, + // which is our input here. + using F3 = gemmlowp::FixedPoint<std::int16_t, 3>; + // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits, + // 2^StateIntegerBits]. It's used to represent the internal state, whose + // number of integer bits is currently dictated by the model. See comment + // on the StateIntegerBits template parameter above. + using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>; + // Implementation of input gate, using fixed-point logistic function. + F3 input_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]); + F0 input_gate_output = gemmlowp::logistic(input_gate_input); + // Implementation of input modulation gate, using fixed-point tanh + // function. 
+ F3 input_modulation_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]); + F0 input_modulation_gate_output = + gemmlowp::tanh(input_modulation_gate_input); + // Implementation of forget gate, using fixed-point logistic function. + F3 forget_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]); + F0 forget_gate_output = gemmlowp::logistic(forget_gate_input); + // Implementation of output gate, using fixed-point logistic function. + F3 output_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]); + F0 output_gate_output = gemmlowp::logistic(output_gate_input); + // Implementation of internal multiplication nodes, still in fixed-point. + F0 input_times_input_modulation = + input_gate_output * input_modulation_gate_output; + FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]); + FS prev_state_times_forget_state = forget_gate_output * prev_state; + // Implementation of internal addition node, saturating. + FS new_state = gemmlowp::SaturatingAdd( + gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation), + prev_state_times_forget_state); + // Implementation of last internal tanh node, still in fixed-point. + F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state); + // Store the new internal state back to memory, as 16-bit integers. + output_state_data_int16[b * output_depth + c] = new_state.raw(); + // Down-scale the output activations to 8-bit integers, saturating, + // and store back to memory. 
+ int16 rescaled_output_activ = + gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); + int16 clamped_output_activ = + std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ)); + output_activ_data_uint8[b * output_depth + c] = + 128 + clamped_output_activ; + } + } +} + template <FusedActivationFunctionType Ac, typename Scalar> void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, int outputs_count, Scalar* const* output_data, @@ -2047,6 +2279,7 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, uint8* output_data, const Dims<4>& output_dims) { + const int32 output_zero_point = 128; const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); const int height = MatchingArraySize(input_dims, 2, output_dims, 2); const int width = MatchingArraySize(input_dims, 1, output_dims, 1); @@ -2075,11 +2308,9 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, using gemmlowp::RoundingDivideByPOT; int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24); - // TODO(mjmatthews): properly wire through this zero offset - output_val_s32 += 127; - if (output_val_s32 == -1) { - // May underflow since we cannot properly represent -1.0f - output_val_s32 = 0; + output_val_s32 += output_zero_point; + if (output_val_s32 == 256) { + output_val_s32 = 255; } TFLITE_DCHECK_GE(output_val_s32, 0); TFLITE_DCHECK_LE(output_val_s32, 255); diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 20c156a932..45031de09c 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -174,6 +174,7 @@ cc_library( "graph_transformations/convert_pure_conv_to_depthwise.cc", "graph_transformations/convert_reorder_axes.cc", "graph_transformations/convert_trivial_addn_to_add.cc", + "graph_transformations/convert_trivial_stack_to_reshape.cc", 
"graph_transformations/convert_trivial_transpose_to_reshape.cc", "graph_transformations/create_im2col_arrays.cc", "graph_transformations/dequantize.cc", @@ -188,7 +189,10 @@ cc_library( "graph_transformations/identify_l2_normalization.cc", "graph_transformations/identify_l2_pool.cc", "graph_transformations/identify_lstm.cc", + "graph_transformations/identify_lstm_merge_inputs.cc", + "graph_transformations/identify_lstm_split_inputs.cc", "graph_transformations/identify_relu1.cc", + "graph_transformations/lstm_utils.cc", "graph_transformations/make_initial_dequantize_operator.cc", "graph_transformations/propagate_array_data_types.cc", "graph_transformations/propagate_fixed_sizes.cc", @@ -204,6 +208,7 @@ cc_library( "graph_transformations/remove_trivial_passthrough.h", "graph_transformations/remove_trivial_quantized_activation_func.cc", "graph_transformations/remove_trivial_reshape.cc", + "graph_transformations/remove_trivial_slice.cc", "graph_transformations/remove_unused_op.cc", "graph_transformations/reorder_activation_functions.cc", "graph_transformations/resolve_batch_normalization.cc", @@ -216,6 +221,7 @@ cc_library( "graph_transformations/resolve_constant_shape_or_rank.cc", "graph_transformations/resolve_constant_stack.cc", "graph_transformations/resolve_constant_strided_slice.cc", + "graph_transformations/resolve_constant_transpose.cc", "graph_transformations/resolve_constant_unary.cc", "graph_transformations/resolve_mean_attributes.cc", "graph_transformations/resolve_pad_attributes.cc", @@ -232,9 +238,11 @@ cc_library( "graph_transformations/resolve_tensorflow_tile.cc", "graph_transformations/resolve_transpose_attributes.cc", "graph_transformations/unfuse_activation_functions.cc", + "graph_transformations/unroll_batch_matmul.cc", ], hdrs = [ "graph_transformations/graph_transformations.h", + "graph_transformations/lstm_utils.h", ], visibility = ["//visibility:public"], deps = [ @@ -245,6 +253,7 @@ cc_library( ":tooling_util", ":types_proto_cc", 
"//tensorflow/core:lib", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index be6d506bf3..70d7a9d4a5 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -46,6 +46,32 @@ using tensorflow::TensorProto; namespace toco { namespace { +tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type) { + switch (data_type) { + case ArrayDataType::kBool: + return tensorflow::DT_BOOL; + case ArrayDataType::kFloat: + return tensorflow::DT_FLOAT; + case ArrayDataType::kUint8: + return tensorflow::DT_UINT8; + case ArrayDataType::kInt32: + return tensorflow::DT_INT32; + case ArrayDataType::kInt64: + return tensorflow::DT_INT64; + case ArrayDataType::kString: + return tensorflow::DT_STRING; + default: + case ArrayDataType::kNone: + LOG(FATAL) << "Unsupported data type: " << static_cast<int>(data_type); + return tensorflow::DT_INVALID; + } +} + +tensorflow::DataType GetTensorFlowDataType(const Model& model, + const string& array_name) { + return GetTensorFlowDataType(model.GetArray(array_name).data_type); +} + // TensorFlow sometimes forbids what it calls "legacy scalars", // which are 1-D shapes where the unique shape size is 1. // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars. 
@@ -212,6 +238,24 @@ void ConvertIntTensorConst(const Model& model, const string& name, } } +void CreateIntTensorConst(const string& name, const std::vector<int32>& data, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + for (auto index : data) { + tensor->add_int_val(index); + } + auto* shape = tensor->mutable_tensor_shape(); + shape->add_dim()->set_size(data.size()); +} + void CreateMatrixShapeTensorConst(const string& name, int rows, int cols, GraphDef* tensorflow_graph) { if (HasAlreadyExportedConst(name, *tensorflow_graph)) { @@ -445,14 +489,23 @@ void ConvertSpaceToDepthOperator(const Model& model, void ConvertFullyConnectedOperator(const Model& model, const FullyConnectedOperator& src_op, GraphDef* tensorflow_graph) { - const string reshape_output = src_op.outputs[0] + "/reshape"; - const string reshape_shape = src_op.outputs[0] + "/reshape/shape"; + // Reshape input activations to have the shape expected by the MatMul. 
+ const string reshape_output = + AvailableArrayName(model, src_op.outputs[0] + "/reshape"); + const string reshape_shape = + AvailableArrayName(model, reshape_output + "/shape"); + const auto& fc_weights_array = model.GetArray(src_op.inputs[1]); + const auto& fc_weights_shape = fc_weights_array.shape(); + CHECK_EQ(fc_weights_shape.dimensions_count(), 2); + CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1, + tensorflow_graph); auto* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(reshape_output); reshape_op->add_input(src_op.inputs[0]); reshape_op->add_input(reshape_shape); - (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*reshape_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); const bool has_bias = src_op.inputs.size() >= 3; string matmul_output = src_op.outputs[0]; @@ -460,38 +513,43 @@ void ConvertFullyConnectedOperator(const Model& model, matmul_output += "/matmul"; } + // Transpose the RHS input from column-major to row-major to match TensorFlow + // expectations. This is the inverse of the transpose we do during + // ResolveTensorFlowMatMul. 
+ const string transpose_output = + AvailableArrayName(model, matmul_output + "/transpose_weights"); + const string transpose_perm = + AvailableArrayName(model, transpose_output + "/perm"); + CreateIntTensorConst(transpose_perm, {1, 0}, tensorflow_graph); + auto transpose_op = tensorflow_graph->add_node(); + transpose_op->set_op("Transpose"); + transpose_op->set_name(transpose_output); + *transpose_op->add_input() = src_op.inputs[1]; + *transpose_op->add_input() = transpose_perm; + (*transpose_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[1])); + (*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32); + auto* matmul_op = tensorflow_graph->add_node(); matmul_op->set_op("MatMul"); - matmul_op->set_name(matmul_output); *matmul_op->add_input() = reshape_output; - *matmul_op->add_input() = src_op.inputs[1]; - (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT); + *matmul_op->add_input() = transpose_op->name(); + (*matmul_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); (*matmul_op->mutable_attr())["transpose_a"].set_b(false); (*matmul_op->mutable_attr())["transpose_b"].set_b(false); CHECK(model.HasArray(src_op.inputs[1])); - const string& fc_weights_name = - WalkUpToConstantArray(model, src_op.inputs[1]); - const auto& fc_weights_array = model.GetArray(fc_weights_name); - const auto& fc_weights_shape = fc_weights_array.shape(); - CHECK_EQ(fc_weights_shape.dimensions_count(), 2); - CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1, - tensorflow_graph); - - CHECK(fc_weights_array.buffer); - CHECK(fc_weights_array.buffer->type == ArrayDataType::kFloat); - const float* fc_weights_data = - fc_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data(); - ConvertFloatTensorConst(fc_weights_name, fc_weights_shape, fc_weights_data, - AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph); + // Add the bias, if it exists. 
if (has_bias) { auto* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(src_op.outputs[0]); biasadd_op->add_input(matmul_output); biasadd_op->add_input(src_op.inputs[2]); - (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*biasadd_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); CHECK(model.HasArray(src_op.inputs[2])); const auto& bias_array = model.GetArray(src_op.inputs[2]); // TODO(b/62904716) Bias arrays should be 1-D, and used directly. @@ -657,6 +715,45 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT); } +void ConvertLogSoftmaxOperator(const Model& model, + const LogSoftmaxOperator& src_op, + GraphDef* tensorflow_graph) { + string softmax_input; + Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); + if (providing_op->type == OperatorType::kTensorFlowReshape) { + softmax_input = src_op.inputs[0]; + } else { + // Insert a reshape operator that reduces the dimensions down to the 2 that + // are required for TensorFlow Logits. 
+ const string reshape_output = + src_op.outputs[0] + "/log_softmax_insert_reshape"; + const string softmax_size = src_op.outputs[0] + "/log_softmax_insert_size"; + softmax_input = reshape_output; + + auto* reshape_op = tensorflow_graph->add_node(); + reshape_op->set_op("Reshape"); + reshape_op->set_name(reshape_output); + *reshape_op->add_input() = src_op.inputs[0]; + *reshape_op->add_input() = softmax_size; + (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const auto& input_shape = model.GetArray(src_op.inputs[0]).shape(); + int32 flattened_size = 1; + for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) { + flattened_size *= input_shape.dims(i); + } + const std::vector<int32> shape_data = { + flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)}; + CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph); + } + + auto* log_softmax_op = tensorflow_graph->add_node(); + log_softmax_op->set_op("LogSoftmax"); + log_softmax_op->set_name(src_op.outputs[0]); + *log_softmax_op->add_input() = softmax_input; + (*log_softmax_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, GraphDef* tensorflow_graph) { const string square_output = src_op.outputs[0] + "/square"; @@ -799,7 +896,8 @@ void ConvertConcatenationOperator(const Model& model, *dc_op->add_input() = input; } *dc_op->add_input() = dummy_axis; - (*dc_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*dc_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32); (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size()); } @@ -813,7 +911,8 @@ void ConvertTensorFlowReshapeOperator(const Model& model, CHECK_EQ(src_op.inputs.size(), 2); *reshape_op->add_input() = src_op.inputs[0]; *reshape_op->add_input() = src_op.inputs[1]; - (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); + 
(*reshape_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); const auto& shape_array = model.GetArray(src_op.inputs[1]); QCHECK(shape_array.data_type == ArrayDataType::kInt32) << "Only int32 shape is supported."; @@ -910,24 +1009,6 @@ void ConvertSplitOperator(const Model& model, tensorflow_graph); } -tensorflow::DataType GetTensorFlowDataType(const Model& model, - const string& array_name) { - auto& dtype = model.GetArray(array_name).data_type; - CHECK(dtype == ArrayDataType::kFloat || dtype == ArrayDataType::kInt32 || - dtype == ArrayDataType::kUint8 || dtype == ArrayDataType::kInt64); - if (dtype == ArrayDataType::kFloat) { - return tensorflow::DT_FLOAT; - } else if (dtype == ArrayDataType::kInt32) { - return tensorflow::DT_INT32; - } else if (dtype == ArrayDataType::kUint8) { - return tensorflow::DT_UINT8; - } else if (dtype == ArrayDataType::kInt64) { - return tensorflow::DT_INT64; - } else { - LOG(FATAL) << "Wrong data type"; - } -} - void ConvertCastOperator(const Model& model, const CastOperator& src_op, GraphDef* tensorflow_graph) { auto* cast_op = tensorflow_graph->add_node(); @@ -982,6 +1063,113 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op, GetTensorFlowDataType(model, src_op.outputs[0])); } +void ConvertTransposeOperator(const Model& model, + const TransposeOperator& src_op, + GraphDef* tensorflow_graph) { + auto* transpose_op = tensorflow_graph->add_node(); + transpose_op->set_op("Transpose"); + transpose_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *transpose_op->add_input() = src_op.inputs[0]; + *transpose_op->add_input() = src_op.inputs[1]; + (*transpose_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); + (*transpose_op->mutable_attr())["Tperm"].set_type( + GetTensorFlowDataType(model, src_op.inputs[1])); +} + +void ConvertTensorFlowShapeOperator(const Model& model, + const TensorFlowShapeOperator& src_op, + GraphDef* 
tensorflow_graph) { + auto* shape_op = tensorflow_graph->add_node(); + shape_op->set_op("Shape"); + shape_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *shape_op->add_input() = src_op.inputs[0]; + (*shape_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); + (*shape_op->mutable_attr())["out_type"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); +} + +void ConvertRankOperator(const Model& model, const RankOperator& src_op, + GraphDef* tensorflow_graph) { + auto* rank_op = tensorflow_graph->add_node(); + rank_op->set_op("Rank"); + rank_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *rank_op->add_input() = src_op.inputs[0]; + (*rank_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); +} + +void ConvertRangeOperator(const Model& model, const RangeOperator& src_op, + GraphDef* tensorflow_graph) { + auto* range_op = tensorflow_graph->add_node(); + range_op->set_op("Range"); + range_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 3); + *range_op->add_input() = src_op.inputs[0]; + *range_op->add_input() = src_op.inputs[1]; + *range_op->add_input() = src_op.inputs[2]; + (*range_op->mutable_attr())["Tidx"].set_type( + GetTensorFlowDataType(src_op.dtype)); +} + +void ConvertStackOperator(const Model& model, const StackOperator& src_op, + GraphDef* tensorflow_graph) { + auto* stack_op = tensorflow_graph->add_node(); + stack_op->set_op("Stack"); + stack_op->set_name(src_op.outputs[0]); + for (const auto& input : src_op.inputs) { + *stack_op->add_input() = input; + } + (*stack_op->mutable_attr())["elem_type"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); + (*stack_op->mutable_attr())["axis"].set_i(src_op.axis); +} + +void ConvertFillOperator(const Model& model, const FillOperator& src_op, + GraphDef* tensorflow_graph) { + auto* fill_op = tensorflow_graph->add_node(); + fill_op->set_op("Fill"); + 
fill_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *fill_op->add_input() = src_op.inputs[0]; + *fill_op->add_input() = src_op.inputs[1]; + (*fill_op->mutable_attr())["index_type"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); + (*fill_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[1])); +} + +void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op, + GraphDef* tensorflow_graph) { + auto* floor_div_op = tensorflow_graph->add_node(); + floor_div_op->set_op("FloorDiv"); + floor_div_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *floor_div_op->add_input() = src_op.inputs[0]; + *floor_div_op->add_input() = src_op.inputs[1]; + (*floor_div_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); +} + +void ConvertExpandDimsOperator(const Model& model, + const ExpandDimsOperator& src_op, + GraphDef* tensorflow_graph) { + auto* expand_dims_op = tensorflow_graph->add_node(); + expand_dims_op->set_op("ExpandDims"); + expand_dims_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *expand_dims_op->add_input() = src_op.inputs[0]; + *expand_dims_op->add_input() = src_op.inputs[1]; + (*expand_dims_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); + (*expand_dims_op->mutable_attr())["Tdim"].set_type( + GetTensorFlowDataType(model, src_op.inputs[1])); +} + void ConvertResizeBilinearOperator(const Model& model, const ResizeBilinearOperator& src_op, GraphDef* tensorflow_graph) { @@ -1447,6 +1635,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kSoftmax) { ConvertSoftmaxOperator(model, static_cast<const SoftmaxOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kLogSoftmax) { + ConvertLogSoftmaxOperator(model, + static_cast<const LogSoftmaxOperator&>(src_op), + tensorflow_graph); } else if 
(src_op.type == OperatorType::kLocalResponseNormalization) { ConvertLocalResponseNormalizationOperator( static_cast<const LocalResponseNormalizationOperator&>(src_op), @@ -1535,6 +1727,32 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kArgMax) { ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTranspose) { + ConvertTransposeOperator( + model, static_cast<const TransposeOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowShape) { + ConvertTensorFlowShapeOperator( + model, static_cast<const TensorFlowShapeOperator&>(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kRank) { + ConvertRankOperator(model, static_cast<const RankOperator&>(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kRange) { + ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kStack) { + ConvertStackOperator(model, static_cast<const StackOperator&>(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kFill) { + ConvertFillOperator(model, static_cast<const FillOperator&>(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kFloorDiv) { + ConvertFloorDivOperator(model, static_cast<const FloorDivOperator&>(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kExpandDims) { + ConvertExpandDimsOperator(model, + static_cast<const ExpandDimsOperator&>(src_op), + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } @@ -1624,6 +1842,30 @@ void ExportTensorFlowGraphDefImplementation(const Model& model, } } // namespace +void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model) { + for (const auto& array_kv : model->GetArrayMap()) { + const string& array_name = 
array_kv.first; + Array& array = *array_kv.second; + if (!array.buffer || !array.minmax) { + continue; + } + const string& wrapped_array_name = + AvailableArrayName(*model, array_name + "/data"); + Array& wrapped_array = model->GetOrCreateArray(wrapped_array_name); + wrapped_array.data_type = array.data_type; + wrapped_array.copy_shape(array.shape()); + wrapped_array.buffer = std::move(array.buffer); + FakeQuantOperator* fakequant_op = new FakeQuantOperator; + fakequant_op->inputs = {wrapped_array_name}; + fakequant_op->outputs = {array_name}; + fakequant_op->minmax.reset(new MinMax); + *fakequant_op->minmax = *array.minmax; + const auto& it = FindOpWithInput(*model, array_name); + model->operators.emplace(it, fakequant_op); + } + CheckInvariants(*model); +} + void ExportTensorFlowGraphDef(const Model& model, string* output_file_contents) { CHECK(output_file_contents->empty()); diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.h b/tensorflow/contrib/lite/toco/export_tensorflow.h index 79682153a8..d7310bb75f 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.h +++ b/tensorflow/contrib/lite/toco/export_tensorflow.h @@ -22,6 +22,8 @@ namespace toco { void ExportTensorFlowGraphDef(const Model& model, string* output_file_contents); +void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model); + } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_stack_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_stack_to_reshape.cc new file mode 100644 index 0000000000..0615b5e6c6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_stack_to_reshape.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ConvertTrivialStackToReshape::Run(Model* model, std::size_t op_index) { + auto stack_it = model->operators.begin() + op_index; + if (stack_it->get()->type != OperatorType::kStack) { + return false; + } + auto* stack_op = static_cast<StackOperator*>(stack_it->get()); + if (stack_op->inputs.size() > 1) { + // Not trivial. + return false; + } + CHECK_EQ(stack_op->outputs.size(), 1); + + const auto& input_array = model->GetArray(stack_op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return false; + } + if (input_array.shape().dimensions_count() == 0) { + // Input array cannot be 0-D. + // (Unsure if this is TF behavior, but was required to get a test to pass.) + return false; + } + + AddMessageF("Converting trivial %s to a reshape", LogName(*stack_op)); + + // Note that we could convert to ExpandDims but toco prefers reshapes. + auto* reshape_op = new TensorFlowReshapeOperator; + reshape_op->inputs = {stack_op->inputs[0]}; + reshape_op->outputs = stack_op->outputs; + + // Create shape param. 
+ string shape_array_name = + AvailableArrayName(*model, stack_op->outputs[0] + "_shape"); + Array& shape_array = model->GetOrCreateArray(shape_array_name); + *(shape_array.mutable_shape()->mutable_dims()) = { + 1 + input_array.shape().dimensions_count()}; + reshape_op->inputs.push_back(shape_array_name); + shape_array.data_type = ArrayDataType::kInt32; + auto& shape_buffer = shape_array.GetMutableBuffer<ArrayDataType::kInt32>(); + shape_buffer.data.push_back(1); + for (int dim : input_array.shape().dims()) { + shape_buffer.data.push_back(dim); + } + + // Replace the operator in the graph. + const auto reshape_it = model->operators.emplace(stack_it, reshape_op); + stack_it = reshape_it + 1; + CHECK_EQ(stack_it->get(), stack_op); + model->operators.erase(stack_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index cf90ebe996..3ab01ae643 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -115,6 +115,7 @@ void RunGraphTransformations(Model* model, const string& message, DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd) +DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialStackToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes) DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors) @@ -124,6 +125,8 @@ DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool) DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) +DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs) +DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs) 
DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator) DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes) @@ -136,6 +139,7 @@ DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity) DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator) DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation) DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialSlice) DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc) DECLARE_GRAPH_TRANSFORMATION(RemoveUnusedOp) DECLARE_GRAPH_TRANSFORMATION(ResolveBatchNormalization) @@ -154,8 +158,10 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose) DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant) DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions) +DECLARE_GRAPH_TRANSFORMATION(UnrollBatchMatMul) DECLARE_GRAPH_TRANSFORMATION(ResolveSpaceToBatchNDAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveBatchToSpaceNDAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index f1892136cf..1b0be85810 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -177,6 +177,106 @@ bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min, output_minmax.max = max; return true; } + +// Propagates MinMax from any of the listed arrays, to all others. +// If multiple of these arrays have MinMax, then these are required +// to agree with each other. 
+bool PropagateMinMaxAmongArrays(Model* model, + const std::vector<string> array_names) { + string reference_array_name; + MinMax* reference_minmax = nullptr; + for (const string& array_name : array_names) { + if (model->GetArray(array_name).minmax) { + reference_array_name = array_name; + reference_minmax = model->GetArray(array_name).minmax.get(); + break; + } + } + // No MinMax info is available to propagate. + if (!reference_minmax) { + return false; + } + bool changed = false; + for (const string& array_name : array_names) { + auto& array = model->GetArray(array_name); + if (array.minmax) { + CHECK(*array.minmax == *reference_minmax) + << "Both the following arrays have minmax, and they disagree: " + << reference_array_name << " and " << array_name + << ". Expected that either only one of them would have minmax, or at " + "least that they would agree."; + } else { + array.GetOrCreateMinMax() = *reference_minmax; + changed = true; + } + } + return changed; +} + +bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) { + CHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS); + CHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS); + + bool changed = false; + changed |= PropagateMinMaxAmongArrays( + model, {op->inputs[LstmCellOperator::PREV_STATE_INPUT], + op->outputs[LstmCellOperator::STATE_OUTPUT]}); + + auto& input_activations = + model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]); + if (!input_activations.minmax) { + auto& minmax = input_activations.GetOrCreateMinMax(); + minmax.min = -1; + minmax.max = 127. / 128.; + changed = true; + } + + auto& prev_output_activations = + model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]); + if (!prev_output_activations.minmax) { + auto& minmax = prev_output_activations.GetOrCreateMinMax(); + minmax.min = -1; + minmax.max = 127. 
/ 128.; + changed = true; + } + + auto& output_concat_temp = + model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP]); + if (!output_concat_temp.minmax) { + auto& minmax = output_concat_temp.GetOrCreateMinMax(); + minmax.min = -1; + minmax.max = 127. / 128.; + changed = true; + } + + auto& output_activations = + model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT]); + if (!output_activations.minmax) { + auto& minmax = output_activations.GetOrCreateMinMax(); + minmax.min = -1; + minmax.max = 127. / 128.; + changed = true; + } + + // (This comment should morph into proper documentation for + // quantization of LSTM models. It isn't just a local implementation detail, + // the training code for LSTM models needs to be adjusted to that.) + // + // Finally, output_activations_temp holds the output of the fully-connected + // node inside the LSTM cell. For it, we hardcode a minmax of [-8, 8]. + // The rationale for that is given in a lengthy comment on the LstmCell + // quantized runtime implementation in reference_ops.h. + auto& output_activations_temp = + model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP]); + if (!output_activations_temp.minmax) { + auto& minmax = output_activations_temp.GetOrCreateMinMax(); + minmax.min = -8; + minmax.max = 8 * 32767. / 32768.; + changed = true; + } + + return changed; +} } // namespace bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { @@ -225,6 +325,10 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForOutput(model, op, -127. 
/ 128., 1.0); break; + case OperatorType::kLstmCell: + changed = HardcodeMinMaxForLstmCell(model, op); + break; + default: break; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc new file mode 100644 index 0000000000..45335fd78c --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <iostream> +#include <string> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { + // Find lstm cell. + auto op_it = model->operators.begin() + op_index; + auto src_op = op_it->get(); + if (src_op->type != OperatorType::kLstmCell) { + return false; + } + + // Already a compact LstmCell with LstmCellOperator::NUM_INPUTS of inputs, + // do not need to merge cell inputs. 
+ if (src_op->inputs.size() == LstmCellOperator::NUM_INPUTS) { + return false; + } + + // Identify prev_activ_input, prev_state_input as required Op inputs, + // using the rnn_states in the model flag. + string prev_activ_input; + if (!GetMatchingRnnArray(model, src_op->outputs[kOutputTensor], + &prev_activ_input)) { + return false; + } + string prev_state_input; + if (!GetMatchingRnnArray(model, src_op->outputs[kCellStateTensor], + &prev_state_input)) { + return false; + } + + // Get LstmCell's cell, input, output size. + int num_cell = model->GetArray(src_op->inputs[kInputToInputWeightsTensor]) + .shape() + .dims(0); + int num_input = model->GetArray(src_op->inputs[kInputToInputWeightsTensor]) + .shape() + .dims(1); + int num_output = + model->GetArray(src_op->inputs[kRecurrentToInputWeightsTensor]) + .shape() + .dims(1); + + // Make sure n_cell and n_output are equal as there is no projection. + CHECK_EQ(num_cell, num_output); + + // Create tensorflow_graphdef style's one big weight tensor. + const string base_name(FindLongestCommonPrefix( + src_op->outputs[kOutputTensor], src_op->outputs[kCellStateTensor])); + string merged_weights = AvailableArrayName(*model, base_name + "weights"); + auto& array = model->GetOrCreateArray(merged_weights); + array.data_type = ArrayDataType::kFloat; + int weights_dim1 = 4 * num_cell; + int weights_dim2 = num_input + num_output; + Shape shape = Shape({weights_dim1, weights_dim2}); + array.copy_shape(shape); + auto& buffer = array.GetMutableBuffer<ArrayDataType::kFloat>(); + buffer.data.resize(weights_dim1 * weights_dim2); + + // Merge 8 small weight tensors to 1 weight tensor. 
+ CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kInputToInputWeightsTensor]), 0, 0); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kInputToCellWeightsTensor]), num_cell, 0); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kInputToForgetWeightsTensor]), + num_cell * 2, 0); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kInputToOutputWeightsTensor]), + num_cell * 3, 0); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kRecurrentToInputWeightsTensor]), 0, + num_input); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kRecurrentToCellWeightsTensor]), num_cell, + num_input); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kRecurrentToForgetWeightsTensor]), + num_cell * 2, num_input); + CopyArrayToSubArray( + buffer, weights_dim2, + model->GetArray(src_op->inputs[kRecurrentToOutputWeightsTensor]), + num_cell * 3, num_input); + + // Create tensorflow_graphdef style's one big bias tensor. + string merged_biases = AvailableArrayName(*model, base_name + "biases"); + auto& bias_array = model->GetOrCreateArray(merged_biases); + bias_array.data_type = ArrayDataType::kFloat; + bias_array.copy_shape(Shape({weights_dim1})); + auto& bias_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>(); + bias_buffer.data.resize(weights_dim1); + + // Merge 4 small bias tensors into a big one. 
+ CopyArrayToSubArray(bias_buffer, weights_dim2, + model->GetArray(src_op->inputs[kInputGateBiasTensor]), 0, + 0); + CopyArrayToSubArray(bias_buffer, weights_dim2, + model->GetArray(src_op->inputs[kCellGateBiasTensor]), + num_cell, 0); + CopyArrayToSubArray(bias_buffer, weights_dim2, + model->GetArray(src_op->inputs[kForgetGateBiasTensor]), + num_cell * 2, 0); + CopyArrayToSubArray(bias_buffer, weights_dim2, + model->GetArray(src_op->inputs[kOutputGateBiasTensor]), + num_cell * 3, 0); + + // Emplace a new LSTM cell operator (use basic 5 inputs kernel). + auto lstm_cell_op = absl::make_unique<LstmCellOperator>(); + + // Compact LstmCell's 5 inputs. + lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS); + lstm_cell_op->inputs[LstmCellOperator::DATA_INPUT] = + src_op->inputs[kInputTensor]; + lstm_cell_op->inputs[LstmCellOperator::WEIGHTS_INPUT] = merged_weights; + lstm_cell_op->inputs[LstmCellOperator::BIASES_INPUT] = merged_biases; + lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] = prev_activ_input; + lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state_input; + + // Reorder LstmCell's 4 outputs. + lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS); + lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] = + src_op->outputs[kOutputTensor]; + lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] = + src_op->outputs[kCellStateTensor]; + lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = + src_op->outputs[kScratchBufferTensor]; + lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = + src_op->outputs[kOutputStateTensor]; + + // Add the op into model. + model->operators.emplace(op_it, std::move(lstm_cell_op)); + AddMessageF("Creating compact LstmCell replacing previous lstm cell"); + + // Delete arrays and operators replaced by the LSTM cell operator. Order is + // important - DeleteArrayIfUnused() only succeeds if dependent operators + // have been removed first. Start at the output and work towards the input. 
+ // Erase curr lstm op being replaced. + DeleteArrayIfUnused(src_op->inputs[kInputToInputWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kInputToForgetWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kInputToCellWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kInputToOutputWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kRecurrentToInputWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kRecurrentToForgetWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kRecurrentToCellWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kRecurrentToOutputWeightsTensor], model); + DeleteArrayIfUnused(src_op->inputs[kInputGateBiasTensor], model); + DeleteArrayIfUnused(src_op->inputs[kForgetGateBiasTensor], model); + DeleteArrayIfUnused(src_op->inputs[kCellGateBiasTensor], model); + DeleteArrayIfUnused(src_op->inputs[kOutputGateBiasTensor], model); + model->operators.erase(FindOp(*model, src_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc new file mode 100644 index 0000000000..eca717680a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -0,0 +1,171 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <iostream> +#include <string> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { + // Find lstm cell. + auto op_it = model->operators.begin() + op_index; + auto curr_op = op_it->get(); + if (curr_op->type != OperatorType::kLstmCell) { + return false; + } + + // Already an extended LstmCell with kExtendedLstmInputCount of inputs, + // do not need to split cell inputs. + if (curr_op->inputs.size() == kExtendedLstmInputCount) { + return false; + } + + // Make sure the WEIGHTS_INPUT and BIASES_INPUT are constant arrays, + // that are able to be split into smaller weight and bias tensors. + if (!IsConstantParameterArray( + *model, curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]) || + !IsConstantParameterArray( + *model, curr_op->inputs[LstmCellOperator::BIASES_INPUT])) { + return false; + } + + // Make sure propagate_fixed_sizes has defined the size of the output. + if (!model->GetArray(curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT]) + .has_shape()) { + return false; + } + + // Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc). + auto lstm_cell_op = absl::make_unique<LstmCellOperator>(); + lstm_cell_op->inputs.resize(kExtendedLstmInputCount); + int num_input = model->GetArray(curr_op->inputs[LstmCellOperator::DATA_INPUT]) + .shape() + .dims(1); + + // n_cell and n_output have the same size when there is no projection. 
+ int num_cell = + model->GetArray(curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT]) + .shape() + .dims(1); + int num_output = num_cell; + + // Data input. + lstm_cell_op->inputs[kInputTensor] = + curr_op->inputs[LstmCellOperator::ACTIV_OUTPUT]; + + // Get original weight tensor and decompose 1 tensor to 8 sub tensors. + Array& kernel = + model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]); + const string base_name(FindLongestCommonPrefix( + curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT], + curr_op->outputs[LstmCellOperator::STATE_OUTPUT])); + + // Input weight tensors of size {n_cell, n_input}. + CopySubArrayToArray( + model, &(lstm_cell_op->inputs[kInputToInputWeightsTensor]), + base_name + "weight_i_i", num_cell, num_input, kernel, 0, 0); + CopySubArrayToArray(model, &(lstm_cell_op->inputs[kInputToCellWeightsTensor]), + base_name + "weight_c_i", num_cell, num_input, kernel, + num_cell, 0); + CopySubArrayToArray( + model, &(lstm_cell_op->inputs[kInputToForgetWeightsTensor]), + base_name + "weight_f_i", num_cell, num_input, kernel, num_cell * 2, 0); + CopySubArrayToArray( + model, &(lstm_cell_op->inputs[kInputToOutputWeightsTensor]), + base_name + "weight_o_i", num_cell, num_input, kernel, num_cell * 3, 0); + + // Recurrent weight tensors of size {n_cell, n_output}. 
+ CopySubArrayToArray( + model, &(lstm_cell_op->inputs[kRecurrentToInputWeightsTensor]), + base_name + "weight_i_r", num_cell, num_output, kernel, 0, num_input); + CopySubArrayToArray(model, + &(lstm_cell_op->inputs[kRecurrentToCellWeightsTensor]), + base_name + "weight_c_r", num_cell, num_output, kernel, + num_cell, num_input); + CopySubArrayToArray(model, + &(lstm_cell_op->inputs[kRecurrentToForgetWeightsTensor]), + base_name + "weight_f_r", num_cell, num_output, kernel, + num_cell * 2, num_input); + CopySubArrayToArray(model, + &(lstm_cell_op->inputs[kRecurrentToOutputWeightsTensor]), + base_name + "weight_o_r", num_cell, num_output, kernel, + num_cell * 3, num_input); + + // Peephole (optional). + CreateOptionalArray(model, &(lstm_cell_op->inputs[kCellToInputWeightsTensor]), + base_name + "peephole_c_i"); + CreateOptionalArray(model, + &(lstm_cell_op->inputs[kCellToForgetWeightsTensor]), + base_name + "peephole_c_f"); + CreateOptionalArray(model, + &(lstm_cell_op->inputs[kCellToOutputWeightsTensor]), + base_name + "peephole_c_o"); + + // Get original bias tensor and decompose 1 tensor to 4 sub tensors + Array& bias = + model->GetArray(curr_op->inputs[LstmCellOperator::BIASES_INPUT]); + CopySubArrayToArray(model, &(lstm_cell_op->inputs[kInputGateBiasTensor]), + base_name + "bias_i", num_cell, 1, bias, 0, 0); + CopySubArrayToArray(model, &(lstm_cell_op->inputs[kCellGateBiasTensor]), + base_name + "bias_c", num_cell, 1, bias, num_cell, 0); + CopySubArrayToArray(model, &(lstm_cell_op->inputs[kForgetGateBiasTensor]), + base_name + "bias_f", num_cell, 1, bias, num_cell * 2, 0); + CopySubArrayToArray(model, &(lstm_cell_op->inputs[kOutputGateBiasTensor]), + base_name + "bias_o", num_cell, 1, bias, num_cell * 3, 0); + + // Projection (optional). 
+ CreateOptionalArray(model, &(lstm_cell_op->inputs[kProjectionWeightsTensor]), + base_name + "proj_weight"); + CreateOptionalArray(model, &(lstm_cell_op->inputs[kProjectionBiasTensor]), + base_name + "proj_bias"); + + // Reorder LstmCell's outputs. + lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS); + lstm_cell_op->outputs[kScratchBufferTensor] = + curr_op->outputs[LstmCellOperator::CONCAT_TEMP]; + lstm_cell_op->outputs[kOutputStateTensor] = + curr_op->outputs[LstmCellOperator::ACTIV_TEMP]; + lstm_cell_op->outputs[kCellStateTensor] = + curr_op->outputs[LstmCellOperator::STATE_OUTPUT]; + lstm_cell_op->outputs[kOutputTensor] = + curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT]; + + // Add the op into model. + model->operators.emplace(op_it, std::move(lstm_cell_op)); + AddMessageF("Creating extended LstmCell replacing previous lstm cell"); + + // Delete arrays and operators replaced by the LSTM cell operator. Order is + // important - DeleteArrayIfUnused() only succeeds if dependent operators + // have been removed first. Start at the output and work towards the input. + // Erase curr lstm op being replaced. + DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model); + DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model); + DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT], + model); + DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT], + model); + model->operators.erase(FindOp(*model, curr_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc new file mode 100644 index 0000000000..910a960589 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.cc @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" + +namespace toco { + +void CreateOptionalArray(Model* model, string* input_array_buffer, + const string& array_name) { + *input_array_buffer = array_name; + model->CreateOptionalArray(array_name); +} + +void CopyArrayData(const Buffer<ArrayDataType::kFloat>& src_buffer, + int src_stride, int src_start_idx1, int src_start_idx2, + Buffer<ArrayDataType::kFloat>* dst_buffer, int dst_stride, + int dst_start_idx1, int dst_start_idx2, int dim1_copy_size, + int dim2_copy_size) { + int src_offset = src_start_idx1 * src_stride + src_start_idx2; + int dst_offset = dst_start_idx1 * dst_stride + dst_start_idx2; + for (int i = 0; i < dim1_copy_size; i++) { + for (int j = 0; j < dim2_copy_size; j++) { + int idx_src = src_offset + i * src_stride + j; + int idx_dst = dst_offset + i * dst_stride + j; + dst_buffer->data[idx_dst] = src_buffer.data[idx_src]; + } + } +} + +Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model, + string* array_name, + const Shape& shape) { + *array_name = AvailableArrayName(*model, *array_name); + auto& array = model->GetOrCreateArray(*array_name); + array.data_type = ArrayDataType::kFloat; + array.copy_shape(shape); + Buffer<ArrayDataType::kFloat>* buffer = + &(array.GetMutableBuffer<ArrayDataType::kFloat>()); + 
buffer->data.resize(RequiredBufferSizeForShape(shape)); + return buffer; +} + +void CopySubArrayToArray(Model* model, string* array_name, + const string& tensor_name, int dim1_size, + int dim2_size, const Array& original_array, + int start_idx1, int start_idx2) { + // Determine whether it's bias or not, create shape, buffer. + bool is_bias = dim2_size == 1; + Shape shape = is_bias ? Shape({dim1_size}) : Shape({dim1_size, dim2_size}); + Buffer<ArrayDataType::kFloat>* buffer = + CreateFloatArrayBuffer(model, array_name, shape); + auto& orig_buffer = original_array.GetBuffer<ArrayDataType::kFloat>(); + + // Copy data from big tensor. + CopyArrayData(orig_buffer, is_bias ? 1 : original_array.shape().dims(1), + start_idx1, start_idx2, buffer, dim2_size, 0, 0, dim1_size, + dim2_size); +} + +void CopyArrayToSubArray(Buffer<ArrayDataType::kFloat>& tensor_buffer, + int tensor_stride, const Array& sub_array, + int start_idx1, int start_idx2) { + // Get tensor data. + bool is_bias = sub_array.shape().dims().size() == 1; + int dim1_copy_size = sub_array.shape().dims()[0]; + int dim2_copy_size = is_bias ? 1 : sub_array.shape().dims(1); + auto& sub_buffer = sub_array.GetBuffer<ArrayDataType::kFloat>(); + + // Copy data from sub tensor. + CopyArrayData(sub_buffer, dim2_copy_size, 0, 0, &tensor_buffer, + is_bias ? 
1 : tensor_stride, start_idx1, start_idx2, + dim1_copy_size, dim2_copy_size); +} + +bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array, + string* rnn_array) { + for (const auto& rnn_state : model->flags.rnn_states()) { + if (rnn_state.back_edge_source_array() == back_edge_source_array) { + *rnn_array = rnn_state.state_array(); + return true; + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h new file mode 100644 index 0000000000..881c2d4dc8 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <iostream> +#include <string> +#include <vector> + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +// For consistency with the parameters defined in extended LstmCell's kernel +// (tensorflow/contrib/lite/kernels/lstm.cc), +// use lowercase for these constants. 
+ +enum ExtendedLstmCellInputs { + kInputTensor = 0, + kInputToInputWeightsTensor = 1, // Optional + kInputToForgetWeightsTensor = 2, + kInputToCellWeightsTensor = 3, + kInputToOutputWeightsTensor = 4, + kRecurrentToInputWeightsTensor = 5, // Optional + kRecurrentToForgetWeightsTensor = 6, + kRecurrentToCellWeightsTensor = 7, + kRecurrentToOutputWeightsTensor = 8, + kCellToInputWeightsTensor = 9, // Optional + kCellToForgetWeightsTensor = 10, // Optional + kCellToOutputWeightsTensor = 11, // Optional + kInputGateBiasTensor = 12, // Optional + kForgetGateBiasTensor = 13, + kCellGateBiasTensor = 14, + kOutputGateBiasTensor = 15, + kProjectionWeightsTensor = 16, // Optional + kProjectionBiasTensor = 17, // Optional + kExtendedLstmInputCount = 18 +}; + +enum ExtendedLstmCellOutputs { + kScratchBufferTensor = 0, + kOutputStateTensor = 1, + kCellStateTensor = 2, + kOutputTensor = 3 +}; + +// Create optional array used for optional tensor in ExtendedLstmCell inputs. +void CreateOptionalArray(Model* model, string* input_array_buffer, + const string& array_name); + +// Create float array and get its buffer. +Buffer<ArrayDataType::kFloat>* CreateFloatArrayBuffer(Model* model, + string* array_name, + const Shape& shape); + +// Copy data from one array to the other one (supports 1D and 2D array), +// for 1D array, the 2nd dim's size is 1. 
+// Arguments:
+//   src_buffer: the source buffer
+//   src_stride: the stride of source buffer, i.e., 2nd dim's size
+//   src_start_idx1: the 1st dim index of start point in src matrix
+//   src_start_idx2: the 2nd dim index of start point in src matrix
+//   dst_buffer: the destination buffer
+//   dst_stride: the stride of destination buffer, i.e., 2nd dim's size
+//   dst_start_idx1: the 1st dim index of start point in dst matrix
+//   dst_start_idx2: the 2nd dim index of start point in dst matrix
+//   dim1_copy_size: 1st dim size of copy data
+//   dim2_copy_size: 2nd dim size of copy data
+void CopyArrayData(const Buffer<ArrayDataType::kFloat>& src_buffer,
+                   int src_stride, int src_start_idx1, int src_start_idx2,
+                   Buffer<ArrayDataType::kFloat>* dst_buffer, int dst_stride,
+                   int dst_start_idx1, int dst_start_idx2, int dim1_copy_size,
+                   int dim2_copy_size);
+
+// Copy a subset of array data and create a smaller array,
+// mostly used for splitting weights and bias for Lstm cell.
+void CopySubArrayToArray(Model* model, string* array_name,
+                         const string& tensor_name, int dim1_size,
+                         int dim2_size, const Array& original_array,
+                         int start_idx1, int start_idx2);
+
+// Copy array data to a large array's submatrix,
+// mostly used for merging weights and bias for Lstm cell.
+void CopyArrayToSubArray(Buffer<ArrayDataType::kFloat>& tensor_buffer,
+                         int tensor_stride, const Array& sub_array,
+                         int start_idx1, int start_idx2);
+
+// Get matching rnn array inputs using rnn_states flag.
+bool GetMatchingRnnArray(Model* model, const string& back_edge_source_array, + string* rnn_array); + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index fa7e70d90b..3de251ed70 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -61,23 +61,42 @@ void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth, output_shape->ReplaceDims({batch, output_height, output_width, output_depth}); } -void ComputeBinaryOperatorOutputSize(const Shape& input_shape1, - const Shape& input_shape2, +void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x, + const Shape& input_shape_y, Array* output_array) { - const int size1 = RequiredBufferSizeForShape(input_shape1); - const int size2 = RequiredBufferSizeForShape(input_shape2); - if (size1 > size2) { - output_array->copy_shape(input_shape1); - } else if (size2 > size1) { - output_array->copy_shape(input_shape2); - } else { - CHECK_EQ(size1, size2); - const int dims1 = input_shape1.dimensions_count(); - const int dims2 = input_shape2.dimensions_count(); - if (dims1 >= dims2) { - output_array->copy_shape(input_shape1); + // This matches the code in BroadcastBinaryOpShapeFn from tensorflow. + // It zips together the two input shapes and pads with 1 to make them the + // same length. For each dimension we broadcast if either dimension is 1 and + // otherwise expect them to match. + int rank_x = input_shape_x.dimensions_count(); + int rank_y = input_shape_y.dimensions_count(); + int rank_out = std::max(rank_x, rank_y); + std::vector<int>* dims_out = output_array->mutable_shape()->mutable_dims(); + dims_out->clear(); + dims_out->reserve(rank_out); + for (int i = 0; i < rank_out; ++i) { + int dim_x = i < (rank_out - rank_x) + ? 
1 + : input_shape_x.dims(i - (rank_out - rank_x)); + bool dim_y_is_one = i < (rank_out - rank_y); + int dim_y = dim_y_is_one ? 1 : input_shape_y.dims(i - (rank_out - rank_y)); + if (dim_x == -1 || dim_y == -1) { + // One or both dimensions is unknown. + QCHECK(false) << "Shapes must be specified"; + } else if (dim_x == 1 || dim_y == 1) { + // Broadcast one dimension to the other that is 1. + if (dim_x == 1 && !dim_y_is_one) { + // Broadcast dim_y to dim_x (1). + dims_out->push_back(dim_y); + } else { + // Broadcast dim_x to dim_y (1). + DCHECK_EQ(dim_y, 1); + dims_out->push_back(dim_x); + } } else { - output_array->copy_shape(input_shape2); + // Expect the dimensions to match. + CHECK_EQ(dim_x, dim_y) << "Dimensions must match"; + dims_out->push_back(dim_x); } } CHECK(output_array->has_shape()); @@ -728,9 +747,8 @@ void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) { } void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { - // I/O arrays should be allocated on creation of op. - QCHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS); - QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS); + // Only required for compact LstmCell with default NUM_INPUTS of inputs. 
+ if (op->inputs.size() != LstmCellOperator::NUM_INPUTS) return; const auto& input_array = model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]); @@ -1218,7 +1236,8 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) { std::vector<int32> const& perm = perm_array.GetBuffer<ArrayDataType::kInt32>().data; CHECK_EQ(perm.size(), input_shape.dimensions_count()) - << "Transpose permutation input must be same length as input dimensions"; + << "Transpose permutation input " << op->inputs[0] + << " must be same length as input dimensions"; std::vector<int>* output_dims = output_array.mutable_shape()->mutable_dims(); for (int i = 0; i < perm.size(); i++) { int axis = perm[i]; @@ -1275,6 +1294,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kRelu1: case OperatorType::kRelu6: case OperatorType::kSoftmax: + case OperatorType::kLogSoftmax: case OperatorType::kLogistic: case OperatorType::kTanh: case OperatorType::kLocalResponseNormalization: @@ -1424,6 +1444,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kLstmCell: ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op)); break; + case OperatorType::kBatchMatMul: case OperatorType::kTensorFlowMatMul: // MatMul operators are converted to FullyConnected, after which their // shapes are propagated. 
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 139c19022e..d7f804ee43 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -49,7 +49,7 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kTensorFlowReshape || type == OperatorType::kTanh || type == OperatorType::kMul || type == OperatorType::kSpaceToDepth || - type == OperatorType::kDepthToSpace; + type == OperatorType::kDepthToSpace || type == OperatorType::kLstmCell; } template <ArrayDataType A> @@ -104,6 +104,9 @@ void QuantizeArray(GraphTransformation* transformation, Model* model, case ArrayDataType::kUint8: return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name, quantization_params); + case ArrayDataType::kInt16: + return QuantizeArray<ArrayDataType::kInt16>(transformation, model, name, + quantization_params); case ArrayDataType::kInt32: return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name, quantization_params); @@ -172,36 +175,62 @@ bool ChooseQuantizationForOperatorInput( if (array.data_type != ArrayDataType::kFloat) { return false; } + + // Quantization of bias vectors + bool is_bias_vector = false; + int activations_input_index; + int weights_input_index; if (op.type == OperatorType::kConv || op.type == OperatorType::kDepthwiseConv || op.type == OperatorType::kFullyConnected) { if (input_index == 2) { - // Quantization of bias vector. - // We need both of the mandatory inputs (input activations and weights) to - // have - // been already quantized. 
- const auto& input_activations = model->GetArray(op.inputs[0]); - const auto& input_weights = model->GetArray(op.inputs[1]); - if (!input_activations.quantization_params || - !input_weights.quantization_params) { - return false; - } - const auto input_activations_scale = - input_activations.quantization_params->scale; - const auto input_weights_scale = input_weights.quantization_params->scale; - quantization_params->scale = - input_activations_scale * input_weights_scale; - quantization_params->zero_point = 0; - *quantized_data_type = ArrayDataType::kInt32; - transformation->AddMessageF( - "Input array %s is a bias vector. Choosing quantization params " - "accordingly.", - input); - return true; + is_bias_vector = true; + activations_input_index = 0; + weights_input_index = 1; + } + } + if (op.type == OperatorType::kLstmCell) { + if (input_index == LstmCellOperator::BIASES_INPUT) { + is_bias_vector = true; + activations_input_index = LstmCellOperator::DATA_INPUT; + weights_input_index = LstmCellOperator::WEIGHTS_INPUT; } } + if (is_bias_vector) { + // Quantization of bias vector. + // We need both of the mandatory inputs (input activations and weights) to + // have been already quantized. + const auto& input_activations = + model->GetArray(op.inputs[activations_input_index]); + const auto& input_weights = model->GetArray(op.inputs[weights_input_index]); + if (!input_activations.quantization_params || + !input_weights.quantization_params) { + return false; + } + const auto input_activations_scale = + input_activations.quantization_params->scale; + const auto input_weights_scale = input_weights.quantization_params->scale; + quantization_params->scale = input_activations_scale * input_weights_scale; + quantization_params->zero_point = 0; + *quantized_data_type = ArrayDataType::kInt32; + transformation->AddMessageF( + "Input array %s is a bias vector. 
Choosing quantization params " + "accordingly.", + input); + return true; + } const MinMax& minmax = GetOrComputeMinMax(model, input); + + if (op.type == OperatorType::kLstmCell) { + if (input_index == LstmCellOperator::PREV_STATE_INPUT) { + GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>( + model->flags, minmax, quantization_params); + *quantized_data_type = ArrayDataType::kInt16; + return true; + } + } + GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax, quantization_params); transformation->AddMessageF( @@ -265,7 +294,7 @@ bool ChooseHardcodedQuantizationForOperatorOutput( if (op.type == OperatorType::kTanh) { // Tanh has the range: [-1, 1]. *quantized_data_type = ArrayDataType::kUint8; - quantization_params->zero_point = 127; + quantization_params->zero_point = 128; quantization_params->scale = 1. / 128.; // 0 should be exactly representable, as values will typically be centered // around 0, with many values near 0. @@ -310,6 +339,15 @@ bool ChooseQuantizationForOperatorOutput( return true; } const MinMax& minmax = GetOrComputeMinMax(model, output); + if (op.type == OperatorType::kLstmCell) { + if (output_index == LstmCellOperator::STATE_OUTPUT || + output_index == LstmCellOperator::ACTIV_TEMP) { + GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>( + model->flags, minmax, quantization_params); + *quantized_data_type = ArrayDataType::kInt16; + return true; + } + } GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax, quantization_params); *quantized_data_type = ArrayDataType::kUint8; @@ -405,41 +443,52 @@ bool Quantize::Run(Model* model, std::size_t op_index) { if (ChooseQuantizationForOperatorInput(this, model, op, input_index, &quantized_data_type, &quantization_params)) { - changed = true; const auto& input = op.inputs[input_index]; if (IsConstantParameterArray(*model, input)) { QuantizeArray(this, model, input, quantized_data_type, quantization_params); - } else if (toco::IsRnnStateArray(*model, 
input)) { - // Simply Quantize the Array - auto& array = model->GetArray(op.inputs[input_index]); - array.GetOrCreateQuantizationParams() = quantization_params; - array.data_type = quantized_data_type; + changed = true; } else { auto dequantize_it = FindOpWithOutput(*model, input); - CHECK(dequantize_it != model->operators.end()) - << "Cannot quantize input \"" << input - << "\" on operator with output \"" << op.outputs[0] - << "\". Nothing feeding input."; - auto* dequantize_op = dequantize_it->get(); - CHECK(dequantize_op->type == OperatorType::kDequantize) - << "Cannot quantize input \"" << input - << "\" on operator with output \"" << op.outputs[0] - << "\". Input is not fed by a Dequantize operator."; - op.inputs[input_index] = dequantize_op->inputs[0]; - // Check if the output of that Dequantize op was not used by any - // other operator. We will then erase that Dequantize op. - if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) { - // If any of the model's output_arrays was pointing to the - // Dequantize op's output, let it point to the Dequantize op's - // input instead. - for (int i = 0; i < model->flags.output_arrays_size(); i++) { - if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) { - model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + if (dequantize_it != model->operators.end()) { + auto* dequantize_op = dequantize_it->get(); + CHECK(dequantize_op->type == OperatorType::kDequantize); + op.inputs[input_index] = dequantize_op->inputs[0]; + // Check if the output of that Dequantize op was not used by any + // other operator. We will then erase that Dequantize op. + if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) { + // If any of the model's output_arrays was pointing to the + // Dequantize op's output, let it point to the Dequantize op's + // input instead. 
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) { + model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + } + } + model->EraseArray(dequantize_op->outputs[0]); + model->operators.erase(dequantize_it); + } + changed = true; + } else { + // This input array is not produced by a Dequantize op. + // We have encountered this situation in RNN graphs, whose cyclic + // nature defeats the basic assumption underlying the quantization + // algorithm implemented here. For now, when we have seen this + // happening, the array in question was a RNN state array itself, + // so let us just implement this case here, and guard that assumption + // with a CHECK. A more general fix would involve revisiting the + // design of this whole Quantization transformation. + bool is_rnn_state_array = false; + for (const auto& rnn_state : model->flags.rnn_states()) { + if (rnn_state.state_array() == input) { + is_rnn_state_array = true; + break; } } - model->EraseArray(dequantize_op->outputs[0]); - model->operators.erase(dequantize_it); + CHECK(is_rnn_state_array); + QuantizeArray(this, model, input, quantized_data_type, + quantization_params); + changed = true; } } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc new file mode 100644 index 0000000000..0cbbcd7c81 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <iterator> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool IsSliceTrivial(const Model& model, const Operator& op, + RemoveTrivialSlice* transformation) { + CHECK(op.type == OperatorType::kSlice); + + // Slices are trivial if they are slicing the entire input contents. 
+ const auto& input_array = model.GetArray(op.inputs[0]); + const auto& output_array = model.GetArray(op.outputs[0]); + if (input_array.has_shape() && output_array.has_shape()) { + if (input_array.shape() == output_array.shape()) { + transformation->AddMessageF( + "%s is trivial because its input and output shapes are equal", + LogName(op)); + return true; + } + } + + return false; +} + +} // namespace + +bool RemoveTrivialSlice::Run(Model* model, std::size_t op_index) { + const auto reshape_it = model->operators.begin() + op_index; + auto* slice_op = reshape_it->get(); + if (slice_op->type != OperatorType::kSlice) { + return false; + } + + if (!IsSliceTrivial(*model, *slice_op, this)) { + return false; + } + + AddMessageF("Removing trivial %s", LogName(*slice_op)); + + CHECK_EQ(slice_op->inputs.size(), 3); + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc index db68968bad..064810b53e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -190,7 +190,7 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { // Remove all the resolved arrays. 
for (const string& input_name : concat_op->inputs) { // Check to prevent removal of shared tensors - if(CountOpsWithInput(*model, input_name) == 1) { + if (CountOpsWithInput(*model, input_name) == 1) { model->EraseArray(input_name); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index 81fe37d7e0..944901ece7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -50,6 +50,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { output_array.data_type = ArrayDataType::kFloat; CHECK(!output_array.buffer); const auto& input_buffer = input_array.GetBuffer<ArrayDataType::kFloat>(); + output_array.GetOrCreateMinMax() = *fakequant_op->minmax; auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kFloat>(); const int size = input_buffer.data.size(); output_buffer.data.resize(size); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc new file mode 100644 index 0000000000..4f984bfde5 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc @@ -0,0 +1,180 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +// Transposes an array up to rank 4. +// This is ShuffleArrayTemplate with non-enum permutation. +template <ArrayDataType Type> +void Transpose(Model* model, const Array& input_array, + const std::vector<int>& perm, Array* output_array) { + const Shape& input_shape = input_array.shape(); + const std::vector<DataType<Type>>& input_data = + input_array.GetBuffer<Type>().data; + + const Shape& output_shape = output_array->shape(); + std::vector<DataType<Type>>& output_data = + output_array->GetMutableBuffer<Type>().data; + output_data.resize(RequiredBufferSizeForShape(output_shape)); + + CHECK(input_shape.dimensions_count() == output_shape.dimensions_count()); + const int dim = input_shape.dimensions_count(); + CHECK_LE(dim, 4); + CHECK(perm.size() >= dim); + for (int i = 0; i < dim; i++) { + CHECK(perm[i] >= 0 && perm[i] < dim); + CHECK(input_shape.dims(perm[i]) == output_shape.dims(i)); + } + Shape extended_input_shape = input_shape; + ExtendShape(&extended_input_shape, 4); + Shape extended_output_shape = output_shape; + ExtendShape(&extended_output_shape, 4); + std::vector<int> extended_perm; + ExtendShuffle(perm, 4, &extended_perm); + + const std::vector<int>& extended_input_dims = extended_input_shape.dims(); + const std::vector<int>& extended_output_dims = extended_output_shape.dims(); + + // TODO(starka): Rework to handle different numbers of dimensions. 
+ int input_strides[4]; + input_strides[3] = 1; + input_strides[2] = extended_input_dims[3]; + input_strides[1] = input_strides[2] * extended_input_dims[2]; + input_strides[0] = input_strides[1] * extended_input_dims[1]; + const int input_stride_0 = input_strides[extended_perm[3]]; + const int input_stride_1 = input_strides[extended_perm[2]]; + const int input_stride_2 = input_strides[extended_perm[1]]; + const int input_stride_3 = input_strides[extended_perm[0]]; + + const int output_size_0 = extended_output_dims[3]; + const int output_size_1 = extended_output_dims[2]; + const int output_size_2 = extended_output_dims[1]; + const int output_size_3 = extended_output_dims[0]; + const int output_stride_0 = 1; + const int output_stride_1 = output_size_0; + const int output_stride_2 = output_stride_1 * output_size_1; + const int output_stride_3 = output_stride_2 * output_size_2; + + for (int i3 = 0; i3 < output_size_3; i3++) { + const DataType<Type>* const input_ptr_3 = + input_data.data() + i3 * input_stride_3; + DataType<Type>* const output_ptr_3 = + output_data.data() + i3 * output_stride_3; + for (int i2 = 0; i2 < output_size_2; i2++) { + const DataType<Type>* const input_ptr_2 = + input_ptr_3 + i2 * input_stride_2; + DataType<Type>* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2; + for (int i1 = 0; i1 < output_size_1; i1++) { + const DataType<Type>* input_ptr = input_ptr_2 + i1 * input_stride_1; + DataType<Type>* output_ptr = output_ptr_2 + i1 * output_stride_1; + DataType<Type>* const output_ptr_end = + output_ptr + output_size_0 * output_stride_0; + while (output_ptr != output_ptr_end) { + *output_ptr = *input_ptr; + input_ptr += input_stride_0; + output_ptr += output_stride_0; + } + } + } + } +} + +} // namespace + +bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + const auto* base_op = it->get(); + if (base_op->type != OperatorType::kTranspose) { + return false; + } + const 
auto* op = static_cast<const TransposeOperator*>(base_op); + + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes. + return false; + } + if (!output_array.has_shape()) { + // Yield until the output shape has been set by PropagateFixedShapes. + return false; + } + + // We require constant inputs. + if (!IsConstantParameterArray(*model, op->inputs[0]) || + !IsConstantParameterArray(*model, op->inputs[1])) { + return false; + } + const Array& input_array = model->GetArray(op->inputs[0]); + + if (input_array.minmax) { + output_array.GetOrCreateMinMax() = input_array.GetMinMax(); + } + + if (op->perm.empty()) { + // Yield until perm has been populated by ResolveTransposeAttributes. + return false; + } + + // We currently only support 1-4 dimensions. + CHECK_LE(op->perm.size(), 4); + + CHECK(!output_array.buffer); + switch (output_array.data_type) { + case ArrayDataType::kFloat: + Transpose<ArrayDataType::kFloat>(model, input_array, op->perm, + &output_array); + break; + case ArrayDataType::kUint8: + Transpose<ArrayDataType::kUint8>(model, input_array, op->perm, + &output_array); + break; + case ArrayDataType::kInt32: + Transpose<ArrayDataType::kInt32>(model, input_array, op->perm, + &output_array); + break; + case ArrayDataType::kInt64: + Transpose<ArrayDataType::kInt64>(model, input_array, op->perm, + &output_array); + break; + default: + LOG(FATAL) << "Unsupported data type given to Transpose op with output \"" + << op->outputs[0] << "\""; + break; + } + + // Erase input arrays if no longer used. + for (const auto& input : op->inputs) { + if (IsDiscardableArray(*model, input) && + CountOpsWithInput(*model, input) == 1) { + model->EraseArray(input); + } + } + + // Erase the operator. 
+ model->operators.erase(it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc index 5c68f87f6c..bc70db0bd8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -60,16 +60,7 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { const auto& output_array_name = reorder_op->outputs[0]; auto& input_array = model->GetArray(input_array_name); auto& output_array = model->GetArray(output_array_name); - string constant_input_array_name = input_array_name; if (!input_array.buffer) { - const auto* op_producing_input = GetOpWithOutput(*model, input_array_name); - if (op_producing_input && - op_producing_input->type == OperatorType::kFakeQuant) { - constant_input_array_name = op_producing_input->inputs[0]; - } - } - auto& constant_input_array = model->GetArray(constant_input_array_name); - if (!constant_input_array.buffer) { return false; } // Yield until output dims have been resolved. 
@@ -77,14 +68,14 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { return false; } // Reorder the input array dims and buffer data - if (constant_input_array.buffer->type == ArrayDataType::kFloat) { - ReorderAxes<float, ArrayDataType::kFloat>( - reorder_op->input_axes_order, reorder_op->output_axes_order, - &constant_input_array, &output_array); - } else if (constant_input_array.buffer->type == ArrayDataType::kInt32) { - ReorderAxes<uint8, ArrayDataType::kUint8>( - reorder_op->input_axes_order, reorder_op->output_axes_order, - &constant_input_array, &output_array); + if (input_array.buffer->type == ArrayDataType::kFloat) { + ReorderAxes<float, ArrayDataType::kFloat>(reorder_op->input_axes_order, + reorder_op->output_axes_order, + &input_array, &output_array); + } else if (input_array.buffer->type == ArrayDataType::kInt32) { + ReorderAxes<uint8, ArrayDataType::kUint8>(reorder_op->input_axes_order, + reorder_op->output_axes_order, + &input_array, &output_array); } else { LOG(FATAL) << "Cannot ReorderAxes unless input buffer is float or uint8."; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index ad1e56888e..f38203c80f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -29,7 +29,36 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) { return false; } - const auto* matmul_op = matmul_it->get(); + const auto* matmul_op = + static_cast<const TensorFlowMatMulOperator*>(matmul_it->get()); + + // Reorder the axes on the second input. TensorFlow uses row-major ordering + // on both inputs, however this is inefficient for the FullyConnected + // operator. 
We'll transpose the second input to be in column-major order now + // and let constant propagation optimize things (if possible). + auto* transpose_op = new TransposeOperator; + transpose_op->inputs = { + matmul_op->inputs[1], + CreateInt32Array( + model, + AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose/perm"), + {1, 0})}; + transpose_op->outputs = { + AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")}; + model->GetOrCreateArray(transpose_op->outputs[0]); + model->operators.emplace(matmul_it, transpose_op); + + // Refresh iterator. + matmul_it = model->operators.begin(); + for (; matmul_it != model->operators.end(); ++matmul_it) { + if (matmul_it->get() == matmul_op) { + break; + } + } + DCHECK_EQ(matmul_it->get(), matmul_op); + + string input_lhs = matmul_op->inputs[0]; + string input_rhs = transpose_op->outputs[0]; // Find the op producing the array passed to this MatMul auto previous_op_it = model->operators.begin(); @@ -47,22 +76,26 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { } Operator* previous_op = (found) ? previous_op_it->get() : nullptr; - // construct the new FullyConnectedOperator + // Construct the new FullyConnectedOperator. auto* fc_op = new FullyConnectedOperator; fc_op->outputs = matmul_op->outputs; - // insert the newly constructed FullyConnectedOperator - auto fc_it = model->operators.emplace(matmul_it, fc_op); + // Insert the newly constructed FullyConnectedOperator. + model->operators.emplace(matmul_it, fc_op) + 1; - // refresh invalidated iterator - matmul_it = fc_it + 1; + // Refresh iterator. + matmul_it = model->operators.begin(); + for (; matmul_it != model->operators.end(); ++matmul_it) { + if (matmul_it->get() == matmul_op) { + break; + } + } DCHECK_EQ(matmul_it->get(), matmul_op); // The way that TensorFlow encodes FullyConnected ops is as a pair // (Reshape, MatMul), so we want to remove the Reshape op and rewrite the - // MatMul - // op as a FullyConnected. 
However, TensorFlow skips the Reshape ops if the - // input doesn't need reshaping, so we can't just match (Reshape, MatMul) + // MatMul op as a FullyConnected. However, TensorFlow skips the Reshape ops if + // the input doesn't need reshaping, so we can't just match (Reshape, MatMul) // pairs. if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) { AddMessageF("Combining %s and %s into %s", LogName(*previous_op), @@ -72,7 +105,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { model->EraseArray(previous_op_output); } CHECK_EQ(previous_op->inputs.size(), 2); - fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]}; + input_lhs = previous_op->inputs[0]; // Only remove Reshape node if no other node uses its output. if (CountOpsWithInput(*model, previous_op_output) == 1) { const auto& previous_op_shape = previous_op->inputs[1]; @@ -95,9 +128,10 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { } else { AddMessageF("Replacing %s by a FullyConnected operator", LogName(*matmul_op)); - fc_op->inputs = {matmul_op->inputs[0], matmul_op->inputs[1]}; } + fc_op->inputs = {input_lhs, input_rhs}; + // erase the MatMul operator model->operators.erase(matmul_it); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD index 8931498782..2f94f9cd8a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD @@ -18,6 +18,17 @@ tf_cc_test( ], ) +tf_cc_test( + name = "lstm_utils_test", + srcs = ["lstm_utils_test.cc"], + deps = [ + "//tensorflow/contrib/lite/toco:graph_transformations", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "@com_google_googletest//:gtest_main", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git 
a/tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc new file mode 100644 index 0000000000..6aae0775d3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/lstm_utils_test.cc @@ -0,0 +1,442 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <tuple> +#include <vector> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +// A gmock matcher that check that elements of a float vector match to a given +// tolerance. 
+std::vector<testing::Matcher<float>> ArrayFloatNear( + const std::vector<float>& values, float max_abs_error = 1e-5) { + std::vector<testing::Matcher<float>> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(testing::FloatNear(v, max_abs_error)); + } + return matchers; +} +} // namespace + +class CopyArrayDataTest : public ::testing::Test { + public: + CopyArrayDataTest() {} + + void PrepareBuffers(Model* model, std::initializer_list<float> src_data, + int src_dim_1, int src_dim_2, + std::initializer_list<float> dst_data, int dst_dim_1, + int dst_dim_2) { + string src_array = "src_array"; + src_buffer_ = CreateFloatArrayBuffer( + model, &src_array, + src_dim_2 == 1 ? Shape({src_dim_1}) : Shape({src_dim_1, src_dim_2})); + PopulateBuffer(src_buffer_, src_data); + string dst_array = "dst_array"; + dst_buffer_ = CreateFloatArrayBuffer( + model, &dst_array, + dst_dim_2 == 1 ? Shape({dst_dim_1}) : Shape({dst_dim_1, dst_dim_2})); + PopulateBuffer(dst_buffer_, dst_data); + } + + Buffer<ArrayDataType::kFloat>* GetSrcBuffer() { return src_buffer_; } + Buffer<ArrayDataType::kFloat>* GetDstBuffer() { return dst_buffer_; } + + void PopulateBuffer(Buffer<ArrayDataType::kFloat>* buffer, + const std::vector<float>& init_data) { + for (int i = 0; i < init_data.size(); i++) { + buffer->data[i] = init_data[i]; + } + } + void UpdateBuffer(Buffer<ArrayDataType::kFloat>* buffer, + std::initializer_list<float> data) { + buffer->data.resize(data.size()); + PopulateBuffer(buffer, data); + } + + private: + Buffer<ArrayDataType::kFloat>* src_buffer_; + Buffer<ArrayDataType::kFloat>* dst_buffer_; +}; + +// Copy from 1 big 2D array to 8 smaller ones. +TEST_F(CopyArrayDataTest, CopyFromBigArrayToSmallerArrayes2D) { + // Init src_buffer, dst_buffer. 
+ Model model; + std::initializer_list<float> large_tf_weight_data = { + -0.320407, -0.108683, 0.406358, -0.410811, -0.285786, -0.15769, + -0.194201, 0.170866, 0.084135, 0.201878, 0.21519, -0.284458, + 0.495906, -0.073818, 0.045578, 0.149816, -0.447073, -0.453578, + 0.116766, 0.21808, 0.047326, -0.001985, 0.402193, 0.315517, + 0.38258, 0.43599, 0.11986, 0.465195, 0.33548, -0.118789, + -0.414159, 0.049269, 0.156108, 0.093459, -0.129103, -0.086274, + 0.186188, -0.324923, 0.4117, -0.344439, 0.240465, -0.343331, + -0.463082, -0.231706, -0.487465, -0.186592, -0.020756, -0.239007, + 0.364817, 0.459106, -0.171447, -0.006542, 0.204032, -0.375317, + -0.041911, 0.051664, 0.320483, 0.155899, 0.156555, -0.249823, + -0.353107, 0.031563, -0.340771, -0.052532, 0.134631, -0.257957, + -0.50141, 0.486939, -0.43853, 0.268426, -0.08754, -0.109447, + -0.502462, -0.028055, -0.121838, -0.046016, 0.105309, -0.070774, + 0.495683, -0.475088, 0.048654, -0.38582, 0.411018, -0.315606, + 0.349628, 0.21698, 0.258989, -0.097902, 0.331218, 0.034602, + 0.418069, -0.089025, -0.417513, 0.07609, 0.393821, 0.404733, + -0.055418, -0.43903, -0.447049, 0.013125, 0.278503, 0.459869, + 0.143755, -0.177335, -0.162247, -0.432371, 0.153714, -0.047403, + -0.446775, -0.418363, 0.019743, 0.042025}; + std::initializer_list<float> tflite_lstm_input_weight = {0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0}; + PrepareBuffers(&model, large_tf_weight_data, /*src_dim_1=*/16, + /*src_dim_2=*/7, tflite_lstm_input_weight, + /*dst_dim_1=*/4, /*dst_dim_2=*/3); + + // Copy src starts at (0,0), size (4,3). 
+ CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + std::vector<float> expected = {-0.320407, -0.108683, 0.406358, 0.170866, + 0.084135, 0.201878, 0.045578, 0.149816, + -0.447073, -0.001985, 0.402193, 0.315517}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy src starts at (4,0), size (4,3). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/4, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + expected = {0.33548, -0.118789, -0.414159, -0.086274, 0.186188, -0.324923, + -0.463082, -0.231706, -0.487465, 0.459106, -0.171447, -0.006542}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy src starts at (8,0), size (4,3). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/8, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + expected = {0.320483, 0.155899, 0.156555, -0.052532, 0.134631, -0.257957, + -0.08754, -0.109447, -0.502462, -0.070774, 0.495683, -0.475088}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy src starts at (12,0), size (4,3). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/12, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/3, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + expected = {0.349628, 0.21698, 0.258989, -0.089025, -0.417513, 0.07609, + -0.447049, 0.013125, 0.278503, -0.432371, 0.153714, -0.047403}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // New dst_buffer with size 16. 
+ std::initializer_list<float> tflite_lstm_recurrent_weight = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + PrepareBuffers(&model, large_tf_weight_data, /*src_dim_1=*/16, + /*src_dim_2=*/7, tflite_lstm_recurrent_weight, + /*dst_dim_1=*/4, /*dst_dim_2=*/4); + + // Copy src starts at (0,3), size (4,4). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/0, + /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + expected = {-0.410811, -0.285786, -0.15769, -0.194201, 0.21519, -0.284458, + 0.495906, -0.073818, -0.453578, 0.116766, 0.21808, 0.047326, + 0.38258, 0.43599, 0.11986, 0.465195}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy src starts at (4,3), size (4,4). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/4, + /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + expected = {0.049269, 0.156108, 0.093459, -0.129103, 0.4117, -0.344439, + 0.240465, -0.343331, -0.186592, -0.020756, -0.239007, 0.364817, + 0.204032, -0.375317, -0.041911, 0.051664}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy src starts at (8,3), size (4,4). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/8, + /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + expected = {-0.249823, -0.353107, 0.031563, -0.340771, -0.50141, 0.486939, + -0.43853, 0.268426, -0.028055, -0.121838, -0.046016, 0.105309, + 0.048654, -0.38582, 0.411018, -0.315606}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy src starts at (12,3), size (4,4). 
+ CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/7, /*src_start_idx1=*/12, + /*src_start_idx2=*/3, GetDstBuffer(), /*dst_stride=*/4, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + expected = {-0.097902, 0.331218, 0.034602, 0.418069, 0.393821, 0.404733, + -0.055418, -0.43903, 0.459869, 0.143755, -0.177335, -0.162247, + -0.446775, -0.418363, 0.019743, 0.042025}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); +} + +// Copy from 1 big 1D array to 4 small ones. +TEST_F(CopyArrayDataTest, CopyFromBigArrayToSmallerArrayes1D) { + // Init src_buffer, dst_buffer. + Model model; + std::initializer_list<float> large_tf_bias_data = { + 0.980304, 0.419808, 0.080278, 0.728548, 0.581674, 0.672433, + 0.434190, 0.844357, 0.229587, 0.785629, 0.022065, 0.753082, + 0.422080, 0.539481, 0.878386, 0.168965}; + std::initializer_list<float> tflite_lstm_i_bias = {0, 0, 0, 0}; + PrepareBuffers(&model, large_tf_bias_data, /*src_dim_1=*/16, + /*src_dim_2=*/1, tflite_lstm_i_bias, + /*dst_dim_1=*/4, /*dst_dim_2=*/1); + + // Copy starts at (0,), size (4,). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + std::vector<float> expected = {0.980304, 0.419808, 0.080278, 0.728548}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy starts at (4,), size (4,). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/4, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + expected = {0.581674, 0.672433, 0.434190, 0.844357}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy starts at (8,), size (4,). 
+ CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/8, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + expected = {0.229587, 0.785629, 0.022065, 0.753082}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); + + // Copy starts at (12,), size (4,). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/12, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + expected = {0.422080, 0.539481, 0.878386, 0.168965}; + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); +} + +// Copy from 8 small 2D arrayes to 1 big one. +TEST_F(CopyArrayDataTest, CopyFromSmallArrayesToBigArray2D) { + // Init src_buffer, dst_buffer. + Model model; + std::initializer_list<float> large_tf_weights_data = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + // Copy dst starts (0, 0), size (4, 3). + std::initializer_list<float> tflite_lstm_i2i_weight = { + -0.320407, -0.108683, 0.406358, 0.170866, 0.084135, 0.201878, + 0.045578, 0.149816, -0.447073, -0.001985, 0.402193, 0.315517}; + PrepareBuffers(&model, tflite_lstm_i2i_weight, /*src_dim_1=*/4, + /*src_dim_2=*/3, large_tf_weights_data, + /*dst_dim_1=*/16, /*dst_dim_2=*/7); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/3, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + + // Copy dst starts (4, 0), size (4, 3). 
+ std::initializer_list<float> tflite_lstm_i2c_weight = { + 0.33548, -0.118789, -0.414159, -0.086274, 0.186188, -0.324923, + -0.463082, -0.231706, -0.487465, 0.459106, -0.171447, -0.006542}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2c_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/3, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/4, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + + // Copy dst starts (8, 0), size (4, 3). + std::initializer_list<float> tflite_lstm_i2f_weight = { + 0.320483, 0.155899, 0.156555, -0.052532, 0.134631, -0.257957, + -0.08754, -0.109447, -0.502462, -0.070774, 0.495683, -0.475088}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2f_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/3, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/8, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + + // Copy dst starts (12, 0), size (4, 3). + std::initializer_list<float> tflite_lstm_i2o_weight = { + 0.349628, 0.21698, 0.258989, -0.089025, -0.417513, 0.07609, + -0.447049, 0.013125, 0.278503, -0.432371, 0.153714, -0.047403}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_i2o_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/3, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/12, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/3); + + // Copy dst starts (0, 3), size (4, 4). 
+ std::initializer_list<float> tflite_lstm_i2r_weight = { + -0.410811, -0.285786, -0.15769, -0.194201, 0.21519, -0.284458, + 0.495906, -0.073818, -0.453578, 0.116766, 0.21808, 0.047326, + 0.38258, 0.43599, 0.11986, 0.465195}; + UpdateBuffer(GetSrcBuffer(), tflite_lstm_i2r_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/4, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/3, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + + // Copy dst starts (4, 3), size (4, 4). + std::initializer_list<float> tflite_lstm_c2r_weight = { + 0.049269, 0.156108, 0.093459, -0.129103, 0.4117, -0.344439, + 0.240465, -0.343331, -0.186592, -0.020756, -0.239007, 0.364817, + 0.204032, -0.375317, -0.041911, 0.051664}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_c2r_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/4, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/4, /*dst_start_idx2=*/3, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + + // Copy dst starts (8, 3), size (4, 4). + std::initializer_list<float> tflite_lstm_f2r_weight = { + -0.249823, -0.353107, 0.031563, -0.340771, -0.50141, 0.486939, + -0.43853, 0.268426, -0.028055, -0.121838, -0.046016, 0.105309, + 0.048654, -0.38582, 0.411018, -0.315606}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_f2r_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/4, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/8, /*dst_start_idx2=*/3, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + + // Copy dst starts (12, 3), size (4, 4). 
+ std::initializer_list<float> tflite_lstm_o2r_weight = { + -0.097902, 0.331218, 0.034602, 0.418069, 0.393821, 0.404733, + -0.055418, -0.43903, 0.459869, 0.143755, -0.177335, -0.162247, + -0.446775, -0.418363, 0.019743, 0.042025}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_o2r_weight); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/4, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/7, + /*dst_start_idx1=*/12, /*dst_start_idx2=*/3, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/4); + + std::vector<float> expected = { + -0.320407, -0.108683, 0.406358, -0.410811, -0.285786, -0.15769, + -0.194201, 0.170866, 0.084135, 0.201878, 0.21519, -0.284458, + 0.495906, -0.073818, 0.045578, 0.149816, -0.447073, -0.453578, + 0.116766, 0.21808, 0.047326, -0.001985, 0.402193, 0.315517, + 0.38258, 0.43599, 0.11986, 0.465195, 0.33548, -0.118789, + -0.414159, 0.049269, 0.156108, 0.093459, -0.129103, -0.086274, + 0.186188, -0.324923, 0.4117, -0.344439, 0.240465, -0.343331, + -0.463082, -0.231706, -0.487465, -0.186592, -0.020756, -0.239007, + 0.364817, 0.459106, -0.171447, -0.006542, 0.204032, -0.375317, + -0.041911, 0.051664, 0.320483, 0.155899, 0.156555, -0.249823, + -0.353107, 0.031563, -0.340771, -0.052532, 0.134631, -0.257957, + -0.50141, 0.486939, -0.43853, 0.268426, -0.08754, -0.109447, + -0.502462, -0.028055, -0.121838, -0.046016, 0.105309, -0.070774, + 0.495683, -0.475088, 0.048654, -0.38582, 0.411018, -0.315606, + 0.349628, 0.21698, 0.258989, -0.097902, 0.331218, 0.034602, + 0.418069, -0.089025, -0.417513, 0.07609, 0.393821, 0.404733, + -0.055418, -0.43903, -0.447049, 0.013125, 0.278503, 0.459869, + 0.143755, -0.177335, -0.162247, -0.432371, 0.153714, -0.047403, + -0.446775, -0.418363, 0.019743, 0.042025}; + + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); +} + +// Copy from 4 small 1D arrayes to 1 big one. +TEST_F(CopyArrayDataTest, CopyFromSmallArrayesToBigArray1D) { + // Init src_buffer, dst_buffer. 
+ Model model; + std::initializer_list<float> large_tf_bias_data = {0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + + std::initializer_list<float> tflite_lstm_i_bias = {0.980304, 0.419808, + 0.080278, 0.728548}; + + PrepareBuffers(&model, tflite_lstm_i_bias, /*src_dim_1=*/4, + /*src_dim_2=*/1, large_tf_bias_data, + /*dst_dim_1=*/16, /*dst_dim_2=*/1); + + // Copy starts at (0,), size (4,). + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/0, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + + // Copy starts at (4,), size (4,). + std::initializer_list<float> tflite_lstm_cell_bias = {0.581674, 0.672433, + 0.434190, 0.844357}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_cell_bias); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/4, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + + // Copy starts at (8,0), size (4,). + std::initializer_list<float> tflite_lstm_forget_bias = {0.229587, 0.785629, + 0.022065, 0.753082}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_forget_bias); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/8, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + + // Copy starts at (12,), size (4,). 
+ std::initializer_list<float> tflite_lstm_output_bias = {0.422080, 0.539481, + 0.878386, 0.168965}; + PopulateBuffer(GetSrcBuffer(), tflite_lstm_output_bias); + CopyArrayData(*(GetSrcBuffer()), + /*src_stride=*/1, /*src_start_idx1=*/0, + /*src_start_idx2=*/0, GetDstBuffer(), /*dst_stride=*/1, + /*dst_start_idx1=*/12, /*dst_start_idx2=*/0, + /*dim1_copy_size=*/4, /*dim2_copy_size=*/1); + + std::vector<float> expected = {0.980304, 0.419808, 0.080278, 0.728548, + 0.581674, 0.672433, 0.434190, 0.844357, + 0.229587, 0.785629, 0.022065, 0.753082, + 0.422080, 0.539481, 0.878386, 0.168965}; + + EXPECT_THAT(GetDstBuffer()->data, ElementsAreArray(ArrayFloatNear(expected))); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc new file mode 100644 index 0000000000..da81ea2ff3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc @@ -0,0 +1,172 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +// Unrolls a BatchMatMul on the batch dimension. +// We need to slice each batch out of the inputs, matmul them individually, then +// stack them all back together at the end. +// +// This transform effectively looks like: +// result_slices = [] +// for bat in B: +// slice_a = tf.reshape(tf.slice(a, [bat, 0, 0], [1, M, N]), [M, N]) +// slice_b = tf.reshape(tf.slice(b, [bat, 0, 0], [1, M, N]), [M, N]) +// slice_c = tf.matmul(slice_a, slice_b) +// result_slices[bat] = slice_c +// result = tf.stack(result_slices) +bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { + auto batch_op_it = model->operators.begin() + op_index; + if (batch_op_it->get()->type != OperatorType::kBatchMatMul) { + return false; + } + const auto* batch_op = + static_cast<const BatchMatMulOperator*>(batch_op_it->get()); + + // We must have the shape of at least one input to know our batch size. + const auto& input_array_a = model->GetArray(batch_op->inputs[0]); + const auto& input_array_b = model->GetArray(batch_op->inputs[1]); + if (!input_array_a.has_shape() || !input_array_b.has_shape()) return false; + + // We only support the rank 3 case. If you are batching on rank > 3 you'll + // have to figure that out. + CHECK_EQ(input_array_a.shape().dimensions_count(), + input_array_b.shape().dimensions_count()) + << "Input dimensions must have the same rank"; + if (input_array_a.shape().dimensions_count() == 2) { + // This is really just a MatMul. This likely means that someone hand-crafted + // a graphdef with a BatchMatMul when they really wanted a MatMul. 
+ AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator", + LogName(*batch_op)); + auto* matmul_op = new TensorFlowMatMulOperator; + matmul_op->inputs = batch_op->inputs; + matmul_op->outputs = batch_op->outputs; + const auto matmul_op_it = model->operators.emplace(batch_op_it, matmul_op); + batch_op_it = matmul_op_it + 1; + CHECK_EQ(batch_op_it->get(), batch_op); + model->operators.erase(batch_op_it); + return true; + } + CHECK_EQ(input_array_a.shape().dimensions_count(), 3) + << "Input arrays must have rank 3"; + + // Perform the matmul for each slice of the batch. + int batch_count = input_array_a.shape().dims(0); + AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op), + batch_count); + auto tail_it = batch_op_it; + std::vector<string> stack_inputs; + for (int batch = 0; batch < batch_count; ++batch) { + std::string batch_name = + std::string(batch_op->outputs[0]) + "_b" + std::to_string(batch); + + // tf.slice(a, ...). + auto* slice_a_op = new SliceOperator; + slice_a_op->inputs = { + batch_op->inputs[0], + CreateInt32Array(model, batch_name + "/slice_a/slice/begin", + {batch, 0, 0}), + CreateInt32Array( + model, batch_name + "/slice_a/slice/size", + {1, input_array_a.shape().dims(1), input_array_a.shape().dims(2)}), + }; + slice_a_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_a")}; + auto& slice_a_op_output = model->GetOrCreateArray(slice_a_op->outputs[0]); + slice_a_op_output.data_type = input_array_a.data_type; + tail_it = model->operators.emplace(tail_it, slice_a_op) + 1; + + // Reshape to remove the first dimension ([1,M,N] -> [M,N]). 
+ auto* slice_a_reshape_op = new TensorFlowReshapeOperator; + slice_a_reshape_op->inputs = { + slice_a_op->outputs[0], + CreateInt32Array(model, batch_name + "/slice_a/reshape/shape", + {-1, input_array_a.shape().dims(2)})}; + slice_a_reshape_op->outputs = { + AvailableArrayName(*model, batch_name + "/slice_a/reshape")}; + auto& slice_a_reshape_op_output = + model->GetOrCreateArray(slice_a_reshape_op->outputs[0]); + slice_a_reshape_op_output.data_type = input_array_a.data_type; + tail_it = model->operators.emplace(tail_it, slice_a_reshape_op) + 1; + + // tf.slice(b, ...). + auto* slice_b_op = new SliceOperator; + slice_b_op->inputs = { + batch_op->inputs[1], + CreateInt32Array(model, batch_name + "/slice_b/slice/begin", {0, 0, 0}), + CreateInt32Array( + model, batch_name + "/slice_b/slice/size", + {1, input_array_b.shape().dims(1), input_array_b.shape().dims(2)}), + }; + slice_b_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_b")}; + auto& slice_b_op_output = model->GetOrCreateArray(slice_b_op->outputs[0]); + slice_b_op_output.data_type = input_array_b.data_type; + tail_it = model->operators.emplace(tail_it, slice_b_op) + 1; + + // Reshape to remove the first dimension ([1,M,N] -> [M,N]). + auto* slice_b_reshape_op = new TensorFlowReshapeOperator; + slice_b_reshape_op->inputs = { + slice_b_op->outputs[0], + CreateInt32Array(model, batch_name + "/slice_b/reshape/shape", + {-1, input_array_b.shape().dims(2)})}; + slice_b_reshape_op->outputs = { + AvailableArrayName(*model, batch_name + "/slice_b/reshape")}; + auto& slice_b_reshape_op_output = + model->GetOrCreateArray(slice_b_reshape_op->outputs[0]); + slice_b_reshape_op_output.data_type = input_array_b.data_type; + tail_it = model->operators.emplace(tail_it, slice_b_reshape_op) + 1; + + // tf.matmul(slice_a, slice_b). 
+ auto* matmul_op = new TensorFlowMatMulOperator; + matmul_op->inputs = {slice_a_reshape_op->outputs[0], + slice_b_reshape_op->outputs[0]}; + matmul_op->outputs = {AvailableArrayName(*model, batch_name)}; + auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]); + matmul_op_output.data_type = input_array_a.data_type; + tail_it = model->operators.emplace(tail_it, matmul_op) + 1; + + // Add to stack. + stack_inputs.push_back(matmul_op->outputs[0]); + } + + // The stack that will join all the individual matmul results together. + auto* stack_op = new StackOperator; + stack_op->inputs = stack_inputs; + stack_op->outputs = {batch_op->outputs[0]}; + stack_op->axis = 0; + model->operators.emplace(tail_it, stack_op); + + // Remove the old batch matmul now that we've unrolled. + batch_op_it = model->operators.begin(); + for (; batch_op_it != model->operators.end(); ++batch_op_it) { + if (batch_op_it->get() == batch_op) { + break; + } + } + CHECK(batch_op_it != model->operators.end()); + CHECK(batch_op_it->get() == batch_op); + model->operators.erase(batch_op_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index c12706e52d..41d6c832f0 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -839,6 +839,7 @@ void ConvertSwitchOperator(const NodeDef& node, op->outputs.push_back(node.name() + ":1"); model->operators.emplace_back(op); } + void ConvertSoftmaxOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -854,6 +855,18 @@ void ConvertSoftmaxOperator(const NodeDef& node, model->operators.emplace_back(softmax); } +void ConvertLogSoftmaxOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "LogSoftmax"); + CheckInputsCount(node, tf_import_flags, 1); + const auto& input_name = 
node.input(0); + auto* log_softmax = new LogSoftmaxOperator; + log_softmax->inputs.push_back(input_name); + log_softmax->outputs.push_back(node.name()); + model->operators.emplace_back(log_softmax); +} + void ConvertLRNOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -962,49 +975,37 @@ void ConvertReshapeOperator(const NodeDef& node, model->operators.emplace_back(op); } +void ConvertBatchMatMulOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CheckInputsCount(node, tf_import_flags, 2); + + // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions + CHECK(!HasAttr(node, "adj_a") || (GetBoolAttr(node, "adj_a") == false)); + CHECK(!HasAttr(node, "adj_b") || (GetBoolAttr(node, "adj_b") == false)); + + auto* batch_matmul = new BatchMatMulOperator; + batch_matmul->inputs = {node.input(0), node.input(1)}; + batch_matmul->outputs = {node.name()}; + model->operators.emplace_back(batch_matmul); +} + void ConvertMatMulOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CheckInputsCount(node, tf_import_flags, 2); - if (node.op() == "MatMul") { - // Transpose flags should be easy to support, but we don't have a - // GraphDef with them to test on at the moment. 
- CHECK_EQ(GetBoolAttr(node, "transpose_a"), false); - CHECK_EQ(GetBoolAttr(node, "transpose_b"), false); - CHECK(!HasAttr(node, "adjoint_a") || - (GetBoolAttr(node, "adjoint_a") == false)); - CHECK(!HasAttr(node, "adjoint_b") || - (GetBoolAttr(node, "adjoint_b") == false)); - } else if (node.op() == "BatchMatMul") { - // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions - CHECK(!HasAttr(node, "adj_a") || (GetBoolAttr(node, "adj_a") == false)); - CHECK(!HasAttr(node, "adj_b") || (GetBoolAttr(node, "adj_b") == false)); - } else { - LOG(FATAL) << "op must be 'MatMul' or 'BatchMatMul'"; - } - const auto& input_name = node.input(0); - const auto& weights_name = node.input(1); - const auto& reordered_weights_name = weights_name + "_reordered"; - // Check if a ReorderAxesOperator was already created for these weights - // (that happens when multiple layers share the same weights). - const Operator* existing_reorder = - GetOpWithOutput(*model, reordered_weights_name); - if (existing_reorder) { - // Check that it is safe to rely on the _reordered naming of the output - // array! - CHECK(existing_reorder->type == OperatorType::kReorderAxes); - } else { - // Create a new ReorderAxesOperator - auto* reorder = new ReorderAxesOperator; - reorder->inputs = {weights_name}; - reorder->outputs = {reordered_weights_name}; - reorder->input_axes_order = AxesOrder::kRC; - reorder->output_axes_order = AxesOrder::kCR; - model->operators.emplace_back(reorder); - } + // Transpose flags should be easy to support, but we don't have a + // GraphDef with them to test on at the moment. 
+ CHECK_EQ(GetBoolAttr(node, "transpose_a"), false); + CHECK_EQ(GetBoolAttr(node, "transpose_b"), false); + CHECK(!HasAttr(node, "adjoint_a") || + (GetBoolAttr(node, "adjoint_a") == false)); + CHECK(!HasAttr(node, "adjoint_b") || + (GetBoolAttr(node, "adjoint_b") == false)); + auto* matmul = new TensorFlowMatMulOperator; - matmul->inputs = {input_name, reordered_weights_name}; + matmul->inputs = {node.input(0), node.input(1)}; matmul->outputs = {node.name()}; model->operators.emplace_back(matmul); } @@ -1859,7 +1860,9 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef( ConvertAvgPoolOperator(node, tf_import_flags, model); } else if (node.op() == "Reshape") { ConvertReshapeOperator(node, tf_import_flags, model); - } else if (node.op() == "MatMul" || node.op() == "BatchMatMul") { + } else if (node.op() == "BatchMatMul") { + ConvertBatchMatMulOperator(node, tf_import_flags, model); + } else if (node.op() == "MatMul") { ConvertMatMulOperator(node, tf_import_flags, model); } else if (node.op() == "Div" || node.op() == "RealDiv") { ConvertDivOperator(node, tf_import_flags, model); @@ -1898,6 +1901,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef( ConvertLRNOperator(node, tf_import_flags, model); } else if (node.op() == "Softmax") { ConvertSoftmaxOperator(node, tf_import_flags, model); + } else if (node.op() == "LogSoftmax") { + ConvertLogSoftmaxOperator(node, tf_import_flags, model); } else if (node.op() == "All") { ConvertAllOperator(node, tf_import_flags, model); } else if (node.op() == "Assert") { diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 447618ec85..0bee694387 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -35,6 +35,7 @@ enum class OperatorType { kAdd, kAddN, kAveragePool, + kBatchMatMul, kBatchNormalization, kConv, kConcatenation, @@ -62,6 +63,7 @@ enum class OperatorType { kRelu1, kRelu6, kSoftmax, + kLogSoftmax, kSub, kTanh, kTransposeConv, @@ -159,9 +161,14 
@@ enum class ArrayDataType { kNone, kBool, kFloat, + kInt8, kUint8, + kInt16, + kUint16, kInt32, + kUint32, kInt64, + kUint64, kString }; @@ -181,18 +188,38 @@ struct DataTypeImpl<ArrayDataType::kFloat> { typedef float Type; }; template <> +struct DataTypeImpl<ArrayDataType::kInt8> { + typedef int8 Type; +}; +template <> struct DataTypeImpl<ArrayDataType::kUint8> { typedef uint8 Type; }; template <> +struct DataTypeImpl<ArrayDataType::kInt16> { + typedef int16 Type; +}; +template <> +struct DataTypeImpl<ArrayDataType::kUint16> { + typedef uint16 Type; +}; +template <> struct DataTypeImpl<ArrayDataType::kInt32> { typedef int32 Type; }; template <> +struct DataTypeImpl<ArrayDataType::kUint32> { + typedef uint32 Type; +}; +template <> struct DataTypeImpl<ArrayDataType::kInt64> { typedef int64 Type; }; template <> +struct DataTypeImpl<ArrayDataType::kUint64> { + typedef uint64 Type; +}; +template <> struct DataTypeImpl<ArrayDataType::kString> { typedef string Type; }; @@ -712,6 +739,19 @@ struct TensorFlowIdentityOperator : Operator { TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {} }; +// Batch matrix multiplication operator. This comes from the (deprecated) +// tf.batch_matmul or a tf.matmul that has rank 3. dims(0) is the batch count +// and it can be trivially unrolled into a series of matmuls on each element. +// +// Inputs: +// inputs[0]: required: the left-hand side matrix +// inputs[1]: required: the right-hand side matrix +// +// TensorFlow equivalent: MatMul +struct BatchMatMulOperator : Operator { + BatchMatMulOperator() : Operator(OperatorType::kBatchMatMul) {} +}; + // General matrix multiplication operator. We don't want to support general // matrix multiplication at inference time, so we resolve it during tooling // to more specific operator types, namely, FullyConnected. @@ -1216,6 +1256,16 @@ struct SoftmaxOperator : Operator { float beta = 0.f; }; +// LogSoftmax activation function. 
+// +// Inputs: +// inputs[0]: required: the logits input array +// +// TensorFlow equivalent: LogSoftmax +struct LogSoftmaxOperator : Operator { + LogSoftmaxOperator() : Operator(OperatorType::kLogSoftmax) {} +}; + // Cast operator. // // Inputs: @@ -1544,7 +1594,7 @@ class Model { bool HasArray(const string& name) const { return arrays.count(name) > 0; } Array& GetArray(const string& name) const { - DCHECK(HasArray(name)); + DCHECK(HasArray(name)) << "Array not found: " << name; return *arrays.at(name); } Array& GetOrCreateArray(const string& name) { diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 04aaedd59d..ff54b350bf 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -524,6 +524,28 @@ class Transpose TocoOperator* op) const override {} }; +class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions, + ::tflite::BuiltinOptions_LSTMOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + // Current toco converter only supports tanh, no clip. + return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/ + ::tflite::ActivationFunctionType_TANH, + /*cell_clip=*/0.0, + /*proj_clip=*/0.0); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + // Only support tanh activation, so check that tflite type is tanh. 
+ CHECK(options.fused_activation_function() == + ::tflite::ActivationFunctionType_TANH); + } +}; + class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions, ::tflite::BuiltinOptions_MeanOptions> { public: @@ -779,6 +801,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice)); + ops.emplace_back( + new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell)); // Custom Operators. ops.emplace_back(new Cast("CAST", OperatorType::kCast)); diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index b715881774..5472c52c96 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -53,18 +53,22 @@ void MakeGeneralGraphTransformationsSet( CHECK(transformations->empty()); transformations->Add(new ConvertExpandDimsToReshape); transformations->Add(new ConvertTrivialAddNToAdd); + transformations->Add(new ConvertTrivialStackToReshape); transformations->Add(new ConvertTrivialTransposeToReshape); transformations->Add(new ConvertReorderAxes); transformations->Add(new ResolveReshapeAttributes); + transformations->Add(new ResolveTransposeAttributes); transformations->Add(new PropagateArrayDataTypes); transformations->Add(new PropagateFixedSizes); transformations->Add(new RemoveTensorFlowAssert); transformations->Add(new RemoveTensorFlowIdentity); transformations->Add(new RemoveTrivialConcatenation); transformations->Add(new RemoveTrivialConcatenationInput); + transformations->Add(new RemoveTrivialSlice); transformations->Add(new RemoveUnusedOp); transformations->Add(new EnsureBiasVectors); transformations->Add(new ResolveReorderAxes); + transformations->Add(new UnrollBatchMatMul); transformations->Add(new ResolveTensorFlowMatMul); transformations->Add(new 
FuseBinaryIntoPrecedingAffine); transformations->Add(new FuseBinaryIntoFollowingAffine); @@ -75,6 +79,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveConstantRange); transformations->Add(new ResolveConstantStack); transformations->Add(new ResolveConstantStridedSlice); + transformations->Add(new ResolveConstantTranspose); transformations->Add(new ResolveConstantUnaryOperator); transformations->Add(new ResolveTensorFlowMerge); transformations->Add(new ResolveSqueezeAttributes); @@ -92,9 +97,9 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveStridedSliceAttributes); transformations->Add(new ResolveSliceAttributes); transformations->Add(new ResolveMeanAttributes); - transformations->Add(new ResolveTransposeAttributes); transformations->Add(new ResolveConstantShapeOrRank); transformations->Add(new MakeInitialDequantizeOperator); + transformations->Add(new ResolveConstantFakeQuant); } bool SupportsQuantization(FileFormat format) { @@ -106,7 +111,8 @@ bool SupportsFusedActivationFunction(FileFormat format) { } bool SupportsLstmCell(FileFormat format) { - return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT); + return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT || + format == TFLITE); } bool SupportsPreallocatedWorkspace(FileFormat format) { @@ -212,9 +218,6 @@ void Transform(const TocoFlags& toco_flags, Model* model) { } else { transformations.Add(new UnfuseActivationFunctions); } - if (output_format != TENSORFLOW_GRAPHDEF) { - transformations.Add(new ResolveConstantFakeQuant); - } if (toco_flags.drop_fake_quant()) { transformations.Add(new DropFakeQuant); } else { @@ -227,9 +230,13 @@ void Transform(const TocoFlags& toco_flags, Model* model) { } } transformations.Add(new ConvertPureConvToDepthwise); - // TFLite export does not yet support fused LSTM cell. 
if (SupportsLstmCell(output_format)) { transformations.Add(new IdentifyLstmCell); + if (output_format == TFLITE) { + transformations.Add(new toco::SplitLstmCellInputs); + } else { + transformations.Add(new toco::MergeLstmCellInputs); + } } transformations.Add(new ResolveConstantConcatenation); RunGraphTransformations(model, "general graph transformations", @@ -267,6 +274,10 @@ void Transform(const TocoFlags& toco_flags, Model* model) { dequantization_transformations); } + if (output_format == TENSORFLOW_GRAPHDEF) { + EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model); + } + LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model); if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) { diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index ff8bc471b7..ce0fde57f4 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -159,6 +159,18 @@ std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput( return model.operators.end(); } +std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput( + Model& model, const string& array_name) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + for (auto& input : it->get()->inputs) { + if (input == array_name) { + return it; + } + } + } + return model.operators.end(); +} + std::vector<std::unique_ptr<Operator>>::const_iterator FindOp( const Model& model, const Operator* op) { for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { @@ -217,6 +229,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Add) HANDLE_OPERATORTYPENAME_CASE(AddN) HANDLE_OPERATORTYPENAME_CASE(AveragePool) + HANDLE_OPERATORTYPENAME_CASE(BatchMatMul) HANDLE_OPERATORTYPENAME_CASE(BatchNormalization) HANDLE_OPERATORTYPENAME_CASE(Conv) HANDLE_OPERATORTYPENAME_CASE(Concatenation) @@ -238,6 +251,7 @@ const char* 
OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Relu6) HANDLE_OPERATORTYPENAME_CASE(ReorderAxes) HANDLE_OPERATORTYPENAME_CASE(Softmax) + HANDLE_OPERATORTYPENAME_CASE(LogSoftmax) HANDLE_OPERATORTYPENAME_CASE(Div) HANDLE_OPERATORTYPENAME_CASE(Tanh) HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll) @@ -406,7 +420,7 @@ void LogArray(int log_level, const Model& model, const string& name) { } if (array.quantization_params) { VLOG(log_level) << " QuantizationParams: zero_point=" - << array.quantization_params->zero_point + << static_cast<int>(array.quantization_params->zero_point) << ", scale=" << array.quantization_params->scale; } } @@ -685,12 +699,10 @@ void CheckNoMissingArray(const Model& model) { for (const auto& op : model.operators) { for (const auto& input : op->inputs) { CHECK(model.HasArray(input) || model.optional_arrays.count(input)) - << "Input: " << input << " missing for op: " - << op->outputs[0] << "."; + << "Input: " << input << " missing for op: " << op->outputs[0] << "."; } for (const auto& output : op->outputs) { - CHECK(model.HasArray(output)) << "Output: " << output - << " missing."; + CHECK(model.HasArray(output)) << "Output: " << output << " missing."; } } CheckNonExistentIOArrays(model); @@ -1296,12 +1308,23 @@ int ElementSize(ArrayDataType data_type) { switch (data_type) { case ArrayDataType::kFloat: return 4; - case ArrayDataType::kInt32: - return 4; + case ArrayDataType::kInt8: + return 1; case ArrayDataType::kUint8: return 1; + case ArrayDataType::kInt16: + return 2; + case ArrayDataType::kUint16: + return 2; + case ArrayDataType::kInt32: + return 4; + case ArrayDataType::kUint32: + return 4; case ArrayDataType::kInt64: return 8; + case ArrayDataType::kUint64: + return 8; + // Usually not critical limitation because strings are only input and/or // output. 
case ArrayDataType::kString: @@ -1399,6 +1422,21 @@ bool IsArrayFullyConnectedWeights(const Model& model, const string& name) { return is_fc_weights; } +string CreateInt32Array(Model* model, const string& param_name, + const std::vector<int>& value) { + auto param_array_name = AvailableArrayName(*model, param_name); + auto& param_array = model->GetOrCreateArray(param_array_name); + param_array.mutable_shape()->ReplaceDims({static_cast<int>(value.size())}); + param_array.data_type = ArrayDataType::kInt32; + auto& param_array_data = + param_array.GetMutableBuffer<ArrayDataType::kInt32>().data; + param_array_data.resize(RequiredBufferSizeForShape(param_array.shape())); + for (int i = 0; i < value.size(); ++i) { + param_array_data[i] = value[i]; + } + return param_array_name; +} + bool EstimateArithmeticOpsCount(const Model& model, int64* result) { int64 total = 0; for (const auto& op : model.operators) { @@ -1446,6 +1484,7 @@ bool EstimateArithmeticOpsCount(const Model& model, int64* result) { } case OperatorType::kLogistic: case OperatorType::kSoftmax: + case OperatorType::kLogSoftmax: case OperatorType::kTanh: { const auto& output_array = model.GetArray(op->outputs[0]); if (!output_array.has_shape()) { @@ -1545,10 +1584,6 @@ void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, } } -namespace { - -// Extend shuffle is designed to match ExtendShape, which pads the shape with -// unit dimensions at the beginning. 
void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim, std::vector<int>* extended_shuffle) { *extended_shuffle = input_shuffle; @@ -1563,8 +1598,6 @@ void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim, } } -} // end anonymous namespace - void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order, AxesOrder output_axes_order, Shape* output_shape) { if (input_axes_order == AxesOrder::kHWIM && @@ -1760,22 +1793,4 @@ void UseArraysExtraInfo(Model* model) { } } -bool IsRnnSourceArray(const toco::Model& model, const string& array_name) { - for (const auto& rnn_state : model.flags.rnn_states()) { - if (array_name == rnn_state.back_edge_source_array()) { - return true; - } - } - return false; -} - -bool IsRnnStateArray(const toco::Model& model, const string& array_name) { - for (const auto& rnn_state : model.flags.rnn_states()) { - if (array_name == rnn_state.state_array()) { - return true; - } - } - return false; -} - } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index a023bab1a0..3addccaa10 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -67,10 +67,15 @@ Operator* GetOpWithOutput(const Model& model, const string& array_name); std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput( Model& model, const string& array_name); + Operator* GetOpWithOutput(const Model& model, const string& array_name); std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput( const Model& model, const string& array_name); + +std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput( + Model& model, const string& array_name); + Operator* GetOpWithInput(const Model& model, const string& array_name); Operator* GetFirstOpWithInput(const Model& model, const string& array_name); @@ -255,6 +260,11 @@ void PrintArrayShape(Model* model, const string& name); void MakeArrayDims(int num_dims, int 
batch, int height, int width, int depth, std::vector<int>* out_dims); +// Defines a constant int32 array with the provided values formatted for use +// as op parameters. +string CreateInt32Array(Model* model, const string& param_name, + const std::vector<int>& value); + bool EstimateArithmeticOpsCount(const Model& model, int64* result); int AxesCount(AxesOrder axes_order); @@ -264,6 +274,11 @@ int AxesCount(AxesOrder axes_order); void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, std::vector<int>* shuffle); +// Extend shuffle is designed to match ExtendShape, which pads the shape with +// unit dimensions at the beginning. +void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim, + std::vector<int>* extended_shuffle); + void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order, AxesOrder output_axes_order, Shape* output_shape); void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order, @@ -285,9 +300,6 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type); void UseArraysExtraInfo(Model* model); -bool IsRnnSourceArray(const toco::Model& model, const string& array_name); -bool IsRnnStateArray(const toco::Model& model, const string& array_name); - } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ diff --git a/tensorflow/contrib/model_pruning/python/layers/layers.py b/tensorflow/contrib/model_pruning/python/layers/layers.py index dfebb9a679..988748ad75 100644 --- a/tensorflow/contrib/model_pruning/python/layers/layers.py +++ b/tensorflow/contrib/model_pruning/python/layers/layers.py @@ -21,7 +21,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np import six from tensorflow.contrib.framework.python.ops import add_arg_scope diff --git a/tensorflow/contrib/ndlstm/BUILD b/tensorflow/contrib/ndlstm/BUILD deleted file mode 100644 index 8403f84188..0000000000 --- 
a/tensorflow/contrib/ndlstm/BUILD +++ /dev/null @@ -1,92 +0,0 @@ -# Description: -# Contains classes implementing 1D and 2D LSTMs for image and signal -# processing problems. - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -package(default_visibility = ["//tensorflow:__subpackages__"]) - -load("//tensorflow:tensorflow.bzl", "tf_py_test") - -py_library( - name = "ndlstm", - srcs = [ - "__init__.py", - "python/__init__.py", - "python/lstm1d.py", - "python/lstm2d.py", - "python/misc.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/rnn:rnn_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - "//tensorflow/python:ops", - "//tensorflow/python:platform", - "//tensorflow/python:random_ops", - "//tensorflow/python:rnn", - "//tensorflow/python:rnn_cell", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - ], -) - -tf_py_test( - name = "lstm1d_test", - srcs = ["python/lstm1d_test.py"], - additional_deps = [ - ":ndlstm", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:variables", - ], -) - -tf_py_test( - name = "lstm2d_test", - srcs = ["python/lstm2d_test.py"], - additional_deps = [ - ":ndlstm", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:variables", - ], -) - -tf_py_test( - name = "misc_test", - srcs = ["python/misc_test.py"], - additional_deps = [ - ":ndlstm", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - 
"//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:variables", - ], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/ndlstm/README.md b/tensorflow/contrib/ndlstm/README.md deleted file mode 100644 index 7ccb57f1b3..0000000000 --- a/tensorflow/contrib/ndlstm/README.md +++ /dev/null @@ -1,31 +0,0 @@ -Library of multidimensional LSTM models and related code. - -# 2D LSTM code - -The 2D LSTM layers take tensors of the form (batch_size, height, width, -depth), compatible with convolutional layers, as inputs. The library -transposes and reshapes these tensors in a way that allows batches of -images to be processed by LSTMs. - -The library currently provides: - - - a separable 2D LSTM layer - - a simple 2D convolutional layer that can be swapped out against 2D LSTM - - layers to reduce images to sequences and images to final state vectors - - layers for sequence classification, pixel-wise classification - -# Other Dimensions - -There is 1D LSTM code in `lstm1d.py`. This code implements 1D LSTM versions -suitable as a basis for higher dimensional LSTMs. It is intended for constant -batch size and uses a different layout. Although the code is perfectly fine for -1D use, you may find other 1D LSTM implementations to be more convenient if you -are interested in sequence problems. - -# Upcoming Changes - - - PyramidLSTM - - support for 3D and 4D - - optional use of native fused LSTM op - - easy-to-use command line drivers and examples - - operators for patch-wise processing diff --git a/tensorflow/contrib/ndlstm/python/lstm1d.py b/tensorflow/contrib/ndlstm/python/lstm1d.py deleted file mode 100644 index 2e2e9086c0..0000000000 --- a/tensorflow/contrib/ndlstm/python/lstm1d.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. 
All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""LSTM layers for sequences.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.contrib.framework.python.ops import variables -from tensorflow.python.framework import constant_op -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import rnn -from tensorflow.python.ops import rnn_cell -from tensorflow.python.ops import variable_scope - - -def _shape(tensor): - return tensor.get_shape().as_list() - - -def ndlstm_base_unrolled(inputs, noutput, scope=None, reverse=False): - """Run an LSTM, either forward or backward. - - This is a 1D LSTM implementation using unrolling and the TensorFlow - LSTM op. 
- - Args: - inputs: input sequence (length, batch_size, ninput) - noutput: depth of output - scope: optional scope name - reverse: run LSTM in reverse - - Returns: - Output sequence (length, batch_size, noutput) - - """ - with variable_scope.variable_scope(scope, "SeqLstmUnrolled", [inputs]): - length, batch_size, _ = _shape(inputs) - lstm_cell = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False) - state = array_ops.zeros([batch_size, lstm_cell.state_size]) - output_u = [] - inputs_u = array_ops.unstack(inputs) - if reverse: - inputs_u = list(reversed(inputs_u)) - for i in xrange(length): - if i > 0: - variable_scope.get_variable_scope().reuse_variables() - output, state = lstm_cell(inputs_u[i], state) - output_u += [output] - if reverse: - output_u = list(reversed(output_u)) - outputs = array_ops.stack(output_u) - return outputs - - -def ndlstm_base_dynamic(inputs, noutput, scope=None, reverse=False): - """Run an LSTM, either forward or backward. - - This is a 1D LSTM implementation using dynamic_rnn and - the TensorFlow LSTM op. - - Args: - inputs: input sequence (length, batch_size, ninput) - noutput: depth of output - scope: optional scope name - reverse: run LSTM in reverse - - Returns: - Output sequence (length, batch_size, noutput) - """ - with variable_scope.variable_scope(scope, "SeqLstm", [inputs]): - lstm_cell = rnn_cell.BasicLSTMCell(noutput) - if reverse: - inputs = array_ops.reverse_v2(inputs, [0]) - outputs, _ = rnn.dynamic_rnn( - lstm_cell, inputs, time_major=True, dtype=inputs.dtype) - if reverse: - outputs = array_ops.reverse_v2(outputs, [0]) - return outputs - - -def ndlstm_base(inputs, noutput, scope=None, reverse=False, dynamic=True): - """Implements a 1D LSTM, either forward or backward. - - This is a base case for multidimensional LSTM implementations, which - tend to be used differently from sequence-to-sequence - implementations. 
For general 1D sequence to sequence - transformations, you may want to consider another implementation - from TF slim. - - Args: - inputs: input sequence (length, batch_size, ninput) - noutput: depth of output - scope: optional scope name - reverse: run LSTM in reverse - dynamic: use dynamic_rnn - - Returns: - Output sequence (length, batch_size, noutput) - - """ - # TODO(tmb) maybe add option for other LSTM implementations, like - # slim.rnn.basic_lstm_cell - if dynamic: - return ndlstm_base_dynamic(inputs, noutput, scope=scope, reverse=reverse) - else: - return ndlstm_base_unrolled(inputs, noutput, scope=scope, reverse=reverse) - - -def sequence_to_final(inputs, noutput, scope=None, name=None, reverse=False): - """Run an LSTM across all steps and returns only the final state. - - Args: - inputs: (length, batch_size, depth) tensor - noutput: size of output vector - scope: optional scope name - name: optional name for output tensor - reverse: run in reverse - - Returns: - Batch of size (batch_size, noutput). - """ - with variable_scope.variable_scope(scope, "SequenceToFinal", [inputs]): - length, batch_size, _ = _shape(inputs) - lstm = rnn_cell.BasicLSTMCell(noutput, state_is_tuple=False) - state = array_ops.zeros([batch_size, lstm.state_size]) - inputs_u = array_ops.unstack(inputs) - if reverse: - inputs_u = list(reversed(inputs_u)) - for i in xrange(length): - if i > 0: - variable_scope.get_variable_scope().reuse_variables() - output, state = lstm(inputs_u[i], state) - outputs = array_ops.reshape(output, [batch_size, noutput], name=name) - return outputs - - -def sequence_softmax(inputs, noutput, scope=None, name=None, linear_name=None): - """Run a softmax layer over all the time steps of an input sequence. 
- - Args: - inputs: (length, batch_size, depth) tensor - noutput: output depth - scope: optional scope name - name: optional name for output tensor - linear_name: name for linear (pre-softmax) output - - Returns: - A tensor of size (length, batch_size, noutput). - - """ - length, _, ninputs = _shape(inputs) - inputs_u = array_ops.unstack(inputs) - output_u = [] - with variable_scope.variable_scope(scope, "SequenceSoftmax", [inputs]): - initial_w = random_ops.truncated_normal([0 + ninputs, noutput], stddev=0.1) - initial_b = constant_op.constant(0.1, shape=[noutput]) - w = variables.model_variable("weights", initializer=initial_w) - b = variables.model_variable("biases", initializer=initial_b) - for i in xrange(length): - with variable_scope.variable_scope(scope, "SequenceSoftmaxStep", - [inputs_u[i]]): - # TODO(tmb) consider using slim.fully_connected(..., - # activation_fn=tf.nn.softmax) - linear = nn_ops.xw_plus_b(inputs_u[i], w, b, name=linear_name) - output = nn_ops.softmax(linear) - output_u += [output] - outputs = array_ops.stack(output_u, name=name) - return outputs diff --git a/tensorflow/contrib/ndlstm/python/lstm1d_test.py b/tensorflow/contrib/ndlstm/python/lstm1d_test.py deleted file mode 100644 index 49b15cc814..0000000000 --- a/tensorflow/contrib/ndlstm/python/lstm1d_test.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for 1D LSTM.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.ndlstm.python import lstm1d as lstm1d_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.ops import gradient_checker -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - -lstm1d = lstm1d_lib - - -def _rand(*size): - return np.random.uniform(size=size).astype("f") - - -class Lstm1DTest(test.TestCase): - - def testSequenceToSequenceDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(17, 1, 5)) - outputs = lstm1d.ndlstm_base(inputs, 8) - variables.global_variables_initializer().run() - names = [v.name for v in variables.trainable_variables()] - self.assertEqual(len(names), 2) - result = outputs.eval() - self.assertEqual(tuple(result.shape), (17, 1, 8)) - - def testSequenceToSequenceGradient(self): - with self.test_session(): - size = (17, 1, 15) - output_size = (17, 1, 8) - inputs = constant_op.constant(_rand(*size)) - outputs = lstm1d.ndlstm_base(inputs, 8, dynamic=False) - variables.global_variables_initializer().run() - gradients = gradients_impl.gradients(outputs, inputs) - if 1: # pylint: disable=using-constant-test - gradients = gradients_impl.gradients(outputs, inputs)[0].eval() - self.assertEqual(gradients.shape, size) - else: - # TODO(tmb) tf.test.compute_gradient error is currently broken - # with dynamic_rnn. Enable this test case eventually. 
- err = gradient_checker.compute_gradient_error( - inputs, size, outputs, output_size, delta=1e-4) - self.assert_(not np.isnan(err)) - self.assert_(err < 0.1) - - def testSequenceToSequenceGradientReverse(self): - with self.test_session(): - size = (17, 1, 15) - output_size = (17, 1, 8) - inputs = constant_op.constant(_rand(*size)) - outputs = lstm1d.ndlstm_base(inputs, 8, reverse=1, dynamic=False) - variables.global_variables_initializer().run() - if 1: # pylint: disable=using-constant-test - gradients = gradients_impl.gradients(outputs, inputs)[0].eval() - self.assertEqual(gradients.shape, size) - else: - # TODO(tmb) tf.test.compute_gradient error is currently broken - # with dynamic_rnn. Enable this test case eventually. - err = gradient_checker.compute_gradient_error( - inputs, size, outputs, output_size, delta=1e-4) - self.assert_(not np.isnan(err)) - self.assert_(err < 0.1) - - def testSequenceToFinalDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(17, 6, 5)) - outputs = lstm1d.sequence_to_final(inputs, 8) - variables.global_variables_initializer().run() - names = [v.name for v in variables.trainable_variables()] - self.assertEqual(len(names), 2) - result = outputs.eval() - self.assertEqual(tuple(result.shape), (6, 8)) - - def testSequenceSoftmaxDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(17, 1, 5)) - outputs = lstm1d.sequence_softmax(inputs, 8) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (17, 1, 8)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/ndlstm/python/lstm2d.py b/tensorflow/contrib/ndlstm/python/lstm2d.py deleted file mode 100644 index ebbb4ccf11..0000000000 --- a/tensorflow/contrib/ndlstm/python/lstm2d.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A small library of functions dealing with LSTMs applied to images. - -Tensors in this library generally have the shape (num_images, height, width, -depth). -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.ndlstm.python import lstm1d -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import variable_scope - - -def _shape(tensor): - """Get the shape of a tensor as an int list.""" - return tensor.get_shape().as_list() - - -def images_to_sequence(tensor): - """Convert a batch of images into a batch of sequences. - - Args: - tensor: a (num_images, height, width, depth) tensor - - Returns: - (width, num_images*height, depth) sequence tensor - """ - - num_image_batches, height, width, depth = _shape(tensor) - transposed = array_ops.transpose(tensor, [2, 0, 1, 3]) - return array_ops.reshape(transposed, - [width, num_image_batches * height, depth]) - - -def sequence_to_images(tensor, num_image_batches): - """Convert a batch of sequences into a batch of images. 
- - Args: - tensor: (num_steps, num_batches, depth) sequence tensor - num_image_batches: the number of image batches - - Returns: - (num_images, height, width, depth) tensor - """ - - width, num_batches, depth = _shape(tensor) - height = num_batches // num_image_batches - reshaped = array_ops.reshape(tensor, - [width, num_image_batches, height, depth]) - return array_ops.transpose(reshaped, [1, 2, 0, 3]) - - -def horizontal_lstm(images, num_filters_out, scope=None): - """Run an LSTM bidirectionally over all the rows of each image. - - Args: - images: (num_images, height, width, depth) tensor - num_filters_out: output depth - scope: optional scope name - - Returns: - (num_images, height, width, num_filters_out) tensor, where - num_steps is width and new num_batches is num_image_batches * height - """ - with variable_scope.variable_scope(scope, "HorizontalLstm", [images]): - batch_size, _, _, _ = _shape(images) - sequence = images_to_sequence(images) - with variable_scope.variable_scope("lr"): - hidden_sequence_lr = lstm1d.ndlstm_base(sequence, num_filters_out // 2) - with variable_scope.variable_scope("rl"): - hidden_sequence_rl = (lstm1d.ndlstm_base( - sequence, num_filters_out - num_filters_out // 2, reverse=1)) - output_sequence = array_ops.concat([hidden_sequence_lr, hidden_sequence_rl], - 2) - output = sequence_to_images(output_sequence, batch_size) - return output - - -def get_blocks(images, kernel_size): - """Split images in blocks - - Args: - images: (num_images, height, width, depth) tensor - kernel_size: A list of length 2 holding the [kernel_height, kernel_width] of - of the pooling. Can be an int if both values are the same. 
- - Returns: - (num_images, height/kernel_height, width/kernel_width, - depth*kernel_height*kernel_width) tensor - """ - with variable_scope.variable_scope("image_blocks"): - batch_size, height, width, chanels = _shape(images) - - if height % kernel_size[0] != 0: - offset = array_ops.zeros([batch_size, - kernel_size[0] - (height % kernel_size[0]), - width, - chanels]) - images = array_ops.concat([images, offset], 1) - batch_size, height, width, chanels = _shape(images) - if width % kernel_size[1] != 0: - offset = array_ops.zeros([batch_size, - height, - kernel_size[1] - (width % kernel_size[1]), - chanels]) - images = array_ops.concat([images, offset], 2) - batch_size, height, width, chanels = _shape(images) - - h, w = int(height / kernel_size[0]), int(width / kernel_size[1]) - features = kernel_size[1] * kernel_size[0] * chanels - - lines = array_ops.split(images, h, axis=1) - line_blocks = [] - for line in lines: - line = array_ops.transpose(line, [0, 2, 3, 1]) - line = array_ops.reshape(line, [batch_size, w, features]) - line_blocks.append(line) - - return array_ops.stack(line_blocks, axis=1) - - -def separable_lstm(images, num_filters_out, - kernel_size=None, nhidden=None, scope=None): - """Run bidirectional LSTMs first horizontally then vertically. - - Args: - images: (num_images, height, width, depth) tensor - num_filters_out: output layer depth - kernel_size: A list of length 2 holding the [kernel_height, kernel_width] of - of the pooling. Can be an int if both values are the same. 
Set to None for - not using blocks - nhidden: hidden layer depth - scope: optional scope name - - Returns: - (num_images, height/kernel_height, width/kernel_width, - num_filters_out) tensor - """ - with variable_scope.variable_scope(scope, "SeparableLstm", [images]): - if nhidden is None: - nhidden = num_filters_out - if kernel_size is not None: - images = get_blocks(images, kernel_size) - hidden = horizontal_lstm(images, nhidden) - with variable_scope.variable_scope("vertical"): - transposed = array_ops.transpose(hidden, [0, 2, 1, 3]) - output_transposed = horizontal_lstm(transposed, num_filters_out) - output = array_ops.transpose(output_transposed, [0, 2, 1, 3]) - return output - - -def reduce_to_sequence(images, num_filters_out, scope=None): - """Reduce an image to a sequence by scanning an LSTM vertically. - - Args: - images: (num_images, height, width, depth) tensor - num_filters_out: output layer depth - scope: optional scope name - - Returns: - A (width, num_images, num_filters_out) sequence. - """ - with variable_scope.variable_scope(scope, "ReduceToSequence", [images]): - batch_size, height, width, depth = _shape(images) - transposed = array_ops.transpose(images, [1, 0, 2, 3]) - reshaped = array_ops.reshape(transposed, - [height, batch_size * width, depth]) - reduced = lstm1d.sequence_to_final(reshaped, num_filters_out) - output = array_ops.reshape(reduced, [batch_size, width, num_filters_out]) - return output - - -def reduce_to_final(images, num_filters_out, nhidden=None, scope=None): - """Reduce an image to a final state by running two LSTMs. - - Args: - images: (num_images, height, width, depth) tensor - num_filters_out: output layer depth - nhidden: hidden layer depth (defaults to num_filters_out) - scope: optional scope name - - Returns: - A (num_images, num_filters_out) batch. 
- """ - with variable_scope.variable_scope(scope, "ReduceToFinal", [images]): - nhidden = nhidden or num_filters_out - batch_size, height, width, depth = _shape(images) - transposed = array_ops.transpose(images, [1, 0, 2, 3]) - reshaped = array_ops.reshape(transposed, - [height, batch_size * width, depth]) - with variable_scope.variable_scope("reduce1"): - reduced = lstm1d.sequence_to_final(reshaped, nhidden) - transposed_hidden = array_ops.reshape(reduced, - [batch_size, width, nhidden]) - hidden = array_ops.transpose(transposed_hidden, [1, 0, 2]) - with variable_scope.variable_scope("reduce2"): - output = lstm1d.sequence_to_final(hidden, num_filters_out) - return output diff --git a/tensorflow/contrib/ndlstm/python/lstm2d_test.py b/tensorflow/contrib/ndlstm/python/lstm2d_test.py deleted file mode 100644 index f1b37d701b..0000000000 --- a/tensorflow/contrib/ndlstm/python/lstm2d_test.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for 2D LSTMs.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.ndlstm.python import lstm2d as lstm2d_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import test_util -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - -lstm2d = lstm2d_lib - - -def _rand(*size): - return np.random.uniform(size=size).astype("f") - - -class Lstm2DTest(test_util.TensorFlowTestCase): - - def testImagesToSequenceDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = lstm2d.images_to_sequence(inputs) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (11, 14, 5)) - - def testSequenceToImagesDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(11, 14, 5)) - outputs = lstm2d.sequence_to_images(inputs, 2) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 7, 11, 5)) - - def testImagesAndSequenceDims(self): - with self.test_session(): - size = (2, 7, 11, 5) - inputs = constant_op.constant(_rand(*size)) - sequence = lstm2d.images_to_sequence(inputs) - outputs = lstm2d.sequence_to_images(sequence, size[0]) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), size) - - def testSeparableLstmDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = lstm2d.separable_lstm(inputs, 8) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 7, 11, 8)) - - def testSeparableLstmDimsBlocks(self): - with self.test_session(): - inputs = 
constant_op.constant(_rand(2, 7, 11, 5)) - outputs = lstm2d.separable_lstm(inputs, 8, kernel_size=[2, 2]) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 4, 6, 8)) - - def testReduceToSequenceDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = lstm2d.reduce_to_sequence(inputs, 8) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 11, 8)) - - def testReduceToFinalDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = lstm2d.reduce_to_final(inputs, 8, 12) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 8)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/ndlstm/python/misc.py b/tensorflow/contrib/ndlstm/python/misc.py deleted file mode 100644 index 38eeff84ca..0000000000 --- a/tensorflow/contrib/ndlstm/python/misc.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Miscellaneous functions useful for nD-LSTM models. - -Some of these functions duplicate functionality in tfslim with -slightly different interfaces. 
- -Tensors in this library generally have the shape (num_images, height, width, -depth). -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.layers.python.layers import layers -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import sparse_ops - - -def _shape(tensor): - """Get the shape of a tensor as an int list.""" - return tensor.get_shape().as_list() - - -def pixels_as_vector(images, scope=None): - """Reduce images to vectors by combining all pixels.""" - with ops.name_scope(scope, "PixelsAsVector", [images]): - batch_size, height, width, depth = _shape(images) - return array_ops.reshape(images, [batch_size, height * width * depth]) - - -def pool_as_vector(images, scope=None): - """Reduce images to vectors by averaging all pixels.""" - with ops.name_scope(scope, "PoolAsVector", [images]): - return math_ops.reduce_mean(images, [1, 2]) - - -def one_hot_planes(labels, num_classes, scope=None): - """Compute 1-hot encodings for planes. - - Given a label, this computes a label image that contains - 1 at all pixels in the plane corresponding to the target - class and 0 in all other planes. - - Args: - labels: (batch_size,) tensor - num_classes: number of classes - scope: optional scope name - - Returns: - Tensor of shape (batch_size, 1, 1, num_classes) with a 1-hot encoding. - """ - with ops.name_scope(scope, "OneHotPlanes", [labels]): - batch_size, = _shape(labels) - batched = layers.one_hot_encoding(labels, num_classes) - return array_ops.reshape(batched, [batch_size, 1, 1, num_classes]) - - -def one_hot_mask(labels, num_classes, scope=None): - """Compute 1-hot encodings for masks. - - Given a label image, this computes the one hot encoding at - each pixel. - - Args: - labels: (batch_size, width, height, 1) tensor containing labels. 
- num_classes: number of classes - scope: optional scope name - - Returns: - Tensor of shape (batch_size, width, height, num_classes) with - a 1-hot encoding. - """ - with ops.name_scope(scope, "OneHotMask", [labels]): - height, width, depth = _shape(labels) - assert depth == 1 - sparse_labels = math_ops.to_int32(array_ops.reshape(labels, [-1, 1])) - sparse_size, _ = _shape(sparse_labels) - indices = array_ops.reshape(math_ops.range(0, sparse_size, 1), [-1, 1]) - concated = array_ops.concat([indices, sparse_labels], 1) - dense_result = sparse_ops.sparse_to_dense(concated, - [sparse_size, num_classes], 1.0, - 0.0) - result = array_ops.reshape(dense_result, [height, width, num_classes]) - return result diff --git a/tensorflow/contrib/ndlstm/python/misc_test.py b/tensorflow/contrib/ndlstm/python/misc_test.py deleted file mode 100644 index fac9023da3..0000000000 --- a/tensorflow/contrib/ndlstm/python/misc_test.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Miscellaneous tests.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.ndlstm.python import misc as misc_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import test_util -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - -misc = misc_lib - - -def _rand(*size): - return np.random.uniform(size=size).astype("f") - - -class LstmMiscTest(test_util.TensorFlowTestCase): - - def testPixelsAsVectorDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = misc.pixels_as_vector(inputs) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 7 * 11 * 5)) - - def testPoolAsVectorDims(self): - with self.test_session(): - inputs = constant_op.constant(_rand(2, 7, 11, 5)) - outputs = misc.pool_as_vector(inputs) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 5)) - - def testOneHotPlanes(self): - with self.test_session(): - inputs = constant_op.constant([0, 1, 3]) - outputs = misc.one_hot_planes(inputs, 4) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (3, 1, 1, 4)) - target = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) - self.assertAllClose(result.reshape(-1), target.reshape(-1)) - - def testOneHotMask(self): - with self.test_session(): - data = np.array([[0, 1, 2], [2, 0, 1]]).reshape(2, 3, 1) - inputs = constant_op.constant(data) - outputs = misc.one_hot_mask(inputs, 3) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (2, 3, 3)) - target = np.array([[[1, 0, 0], [0, 1, 0]], [[0, 1, 
0], [0, 0, 1]], - [[0, 0, 1], [1, 0, 0]]]).transpose(1, 2, 0) - self.assertAllClose(result.reshape(-1), target.reshape(-1)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/nn/python/ops/alpha_dropout.py b/tensorflow/contrib/nn/python/ops/alpha_dropout.py index d7b61a5844..2f92d05ba8 100644 --- a/tensorflow/contrib/nn/python/ops/alpha_dropout.py +++ b/tensorflow/contrib/nn/python/ops/alpha_dropout.py @@ -18,7 +18,6 @@ from __future__ import print_function import numbers -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -26,7 +25,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_impl def alpha_dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name diff --git a/tensorflow/contrib/nn/python/ops/alpha_dropout_test.py b/tensorflow/contrib/nn/python/ops/alpha_dropout_test.py index 2ff978ab89..54a98e6f14 100644 --- a/tensorflow/contrib/nn/python/ops/alpha_dropout_test.py +++ b/tensorflow/contrib/nn/python/ops/alpha_dropout_test.py @@ -21,7 +21,6 @@ from __future__ import print_function from tensorflow.contrib.nn.python.ops.alpha_dropout import alpha_dropout from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import nn_impl diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py index b0a257d264..825c08a09a 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py +++ 
b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py @@ -21,12 +21,9 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.opt.python.training import nadam_optimizer -from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test diff --git a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py index 6a09f70f44..348623d8f8 100644 --- a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py +++ b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py @@ -18,13 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - +# pylint: disable=unused-import from tensorflow.contrib.periodic_resample.python.ops import gen_periodic_resample_op from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample from tensorflow.contrib.util import loader from tensorflow.python.platform import resource_loader +# pylint: enable=unused-import _periodic_resample_op = loader.load_op_library( resource_loader.get_path_to_datafile('_periodic_resample_op.so')) diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/py2tf/utils/BUILD index 502720047e..4b7a4b16c7 100644 --- a/tensorflow/contrib/py2tf/utils/BUILD +++ b/tensorflow/contrib/py2tf/utils/BUILD @@ -22,6 +22,8 @@ py_library( "__init__.py", "context_managers.py", "misc.py", + "multiple_dispatch.py", + "type_check.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -48,3 
+50,23 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "multiple_dispatch_test", + srcs = ["multiple_dispatch_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "type_check_test", + srcs = ["type_check_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/py2tf/utils/__init__.py b/tensorflow/contrib/py2tf/utils/__init__.py index 2ab0d7b9fd..1cbb0e0029 100644 --- a/tensorflow/contrib/py2tf/utils/__init__.py +++ b/tensorflow/contrib/py2tf/utils/__init__.py @@ -20,3 +20,6 @@ from __future__ import print_function from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns from tensorflow.contrib.py2tf.utils.misc import alias_tensors +from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond +from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while +from tensorflow.contrib.py2tf.utils.type_check import is_tensor diff --git a/tensorflow/contrib/py2tf/utils/multiple_dispatch.py b/tensorflow/contrib/py2tf/utils/multiple_dispatch.py new file mode 100644 index 0000000000..d8a67255a4 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/multiple_dispatch.py @@ -0,0 +1,54 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utilities for type-dependent behavior used in py2tf-generated code.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.utils.type_check import is_tensor +from tensorflow.python.ops import control_flow_ops + + +def run_cond(condition, true_fn, false_fn): + if is_tensor(condition): + return control_flow_ops.cond(condition, true_fn, false_fn) + else: + return py_cond(condition, true_fn, false_fn) + + +def py_cond(condition, true_fn, false_fn): + if condition: + return true_fn() + else: + return false_fn() + + +def run_while(cond_fn, body_fn, init_args): + if not isinstance(init_args, (tuple, list)) or not init_args: + raise ValueError( + 'init_args must be a non-empty list or tuple, found %s' % init_args) + + if is_tensor(*init_args): + return control_flow_ops.while_loop(cond_fn, body_fn, init_args) + else: + return py_while_loop(cond_fn, body_fn, init_args) + + +def py_while_loop(cond_fn, body_fn, init_args): + state = init_args + while cond_fn(*state): + state = body_fn(*state) + return state diff --git a/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py b/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py new file mode 100644 index 0000000000..5bb4d4086b --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for multiple_dispatch.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.contrib.py2tf.utils import multiple_dispatch +from tensorflow.python.client.session import Session +from tensorflow.python.framework.constant_op import constant +from tensorflow.python.platform import test + + +class MultipleDispatchTest(test.TestCase): + + def test_run_cond_python(self): + true_fn = lambda: 2.0 + false_fn = lambda: 3.0 + self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2.0) + self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3.0) + + def test_run_cond_tf(self): + + true_fn = lambda: constant([2.0]) + false_fn = lambda: constant([3.0]) + with Session() as sess: + out = multiple_dispatch.run_cond(constant(True), true_fn, false_fn) + self.assertEqual(sess.run(out), 2.0) + out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn) + self.assertEqual(sess.run(out), 3.0) + + def test_run_while_python(self): + cond_fn = lambda x, t, s: x > t + body_fn = lambda x, t, s: (x * s, t, s) + + x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 1.0, 0.5]) + self.assertEqual(x, 0.75) + + x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 4.0, 0.5]) + self.assertEqual(x, 3.0) + + def test_run_while_tf(self): + cond_fn = lambda x, t, s: x > t + body_fn = lambda x, t, s: (x * s, t, s) + + with Session() as sess: + x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, + [constant(3.0), 1.0, 0.5]) + self.assertEqual(sess.run(x), 0.75) + + x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, + [constant(3.0), 4.0, 0.5]) + self.assertEqual(sess.run(x), 3.0) + + +if __name__ == '__main__': + test.main() diff --git 
a/tensorflow/contrib/ndlstm/python/__init__.py b/tensorflow/contrib/py2tf/utils/type_check.py index 1aa51a6ec4..9ca2dec872 100644 --- a/tensorflow/contrib/ndlstm/python/__init__.py +++ b/tensorflow/contrib/py2tf/utils/type_check.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,14 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Init file, giving convenient access to all ndlstm ops.""" +"""Utilities used in py2tf-generated code.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=wildcard-import,g-importing-member -from tensorflow.contrib.ndlstm.python.lstm1d import * -from tensorflow.contrib.ndlstm.python.lstm2d import * -from tensorflow.contrib.ndlstm.python.misc import * -# pylint: enable=wildcard-import +from tensorflow.python.framework import tensor_util + + +def is_tensor(*args): + """Check if all arguments are tensors. + + Args: + *args: Python objects that may or may not be tensors. + + Returns: + True if all *args are TensorFlow types, False if one or more are not. + """ + return any([tensor_util.is_tensor(a) for a in args]) diff --git a/tensorflow/contrib/ndlstm/__init__.py b/tensorflow/contrib/py2tf/utils/type_check_test.py index da89bb4ab6..7d0428e9cc 100644 --- a/tensorflow/contrib/ndlstm/__init__.py +++ b/tensorflow/contrib/py2tf/utils/type_check_test.py @@ -12,10 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +"""Tests for type_check.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.ndlstm.python import lstm2d -from tensorflow.contrib.ndlstm.python import lstm1d +import numpy + +from tensorflow.contrib.py2tf.utils import type_check +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class TypeCheckTest(test.TestCase): + + def test_checks(self): + self.assertTrue(type_check.is_tensor(constant_op.constant([1, 2, 3]))) + self.assertTrue( + type_check.is_tensor(test_util.variables.Variable([1, 2, 3]))) + self.assertTrue( + type_check.is_tensor( + test_util.array_ops.placeholder(test_util.dtypes.float32))) + self.assertFalse(type_check.is_tensor(3)) + self.assertFalse(type_check.is_tensor(numpy.eye(3))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index b7d525a1fa..ada336e623 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -75,7 +75,9 @@ py_library( ":graph_matcher", ":input_to_ops", "//tensorflow/contrib/graph_editor:graph_editor_py", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:layers", "//tensorflow/python:math_ops", @@ -83,6 +85,7 @@ py_library( "//tensorflow/python:nn_ops", "//tensorflow/python:ops", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python:variables", ], ) @@ -162,7 +165,6 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", "//tensorflow/python:session", "//tensorflow/python:variables", 
@@ -174,7 +176,7 @@ py_library( srcs = ["python/quantize.py"], srcs_version = "PY2AND3", deps = [ - ":common", + ":graph_matcher", ":input_to_ops", ":quant_ops", "//tensorflow/contrib/graph_editor:graph_editor_py", diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index f80d427ff0..0a8e35080c 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -53,7 +53,7 @@ def LastValueQuantize(inputs, init_max=6.0, updates_collection=ops.GraphKeys.UPDATE_OPS, vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - scope=None, + name_prefix='LastValueQuant', reuse=None, is_training=True, num_bits=8, @@ -73,7 +73,7 @@ def LastValueQuantize(inputs, computation. vars_collection: (Optional) collection where to store variables for quantization interval ends. - scope: Optional scope for variable_scope. + name_prefix: name_prefix for created nodes. reuse: whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. is_training: Whether the op is applied to a training or eval graph. @@ -84,13 +84,13 @@ def LastValueQuantize(inputs, a tensor containing quantized values. """ with variable_scope.variable_scope( - scope, 'LastValueQuantize', values=[inputs], reuse=reuse): + None, default_name=name_prefix, values=[inputs], reuse=reuse): input_shape = inputs.get_shape() input_dim = len(input_shape) if per_channel: # Only support quantizing 1-, 2- and 4-dimensional tensors. 
assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' - ' scope: %s' % (input_shape, scope)) + ' scope: %s' % (input_shape, name_prefix)) min_max_shape = [input_shape[-1]] else: min_max_shape = [] @@ -165,7 +165,7 @@ def MovingAvgQuantize(inputs, ema_decay=0.999, updates_collection=ops.GraphKeys.UPDATE_OPS, vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - scope=None, + name_prefix='MovingAvgQuantize', reuse=None, is_training=True, num_bits=8, @@ -186,7 +186,7 @@ def MovingAvgQuantize(inputs, computation. vars_collection: (Optional) collection where to store variables for quantization interval ends. - scope: Optional scope for variable_scope. + name_prefix: name_prefix for created nodes. reuse: whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given. is_training: Whether the op is applied to a training or eval graph. @@ -197,13 +197,13 @@ def MovingAvgQuantize(inputs, a tensor containing quantized values. """ with variable_scope.variable_scope( - scope, 'MovingAvgQuantize', values=[inputs], reuse=reuse): + None, default_name=name_prefix, values=[inputs], reuse=reuse): input_shape = inputs.get_shape() input_dim = len(input_shape) if per_channel: # Only support quantizing 1-, 2- and 4-dimensional tensors. assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' - ' scope: %s' % (input_shape, scope)) + ' scope: %s' % (input_shape, name_prefix)) min_max_shape = [input_shape[-1]] else: min_max_shape = [] diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 50a2b4c91c..1a63b0a2ce 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Logic to update a Tensorflow model graph with quantization operations.""" +"""Logic to update a TensorFlow model graph with quantization operations.""" from __future__ import absolute_import from __future__ import division @@ -20,7 +20,7 @@ from __future__ import print_function import re from tensorflow.contrib import graph_editor -from tensorflow.contrib.quantize.python import common +from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.contrib.quantize.python import input_to_ops from tensorflow.contrib.quantize.python import quant_ops from tensorflow.python.framework import ops @@ -28,30 +28,29 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import training_util -# Operation types used to select operations of interest. +# Quantizable operation types that are supported by the quantization rewrite. _QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'} -# Custom key for storing and retrieving update ops used by quantizing nodes. -_UPDATE_QUANT_OPS = 'update_quant_ops' +# Activations that are supported by the quantization rewrite. +_ACTIVATION_TYPES = {'Relu', 'Relu6', 'Identity'} + +# Weight types that are supported by the quantization rewrite. +# TODO(suharshs): Add support for ResourceVariable. +_WEIGHT_TYPES = {'Variable', 'VariableV2'} def Quantize(graph, weight_bits=8, - weight_narrow_range=False, activation_bits=8, ema_decay=0.999, quant_delay=None, vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - is_training=True, - quantize_folded_weights_use_ema=False): + is_training=True): """Updates graph with quantization operations. Args: graph: Graph to modify. weight_bits: Number of bits to use for quantizing weights. - weight_narrow_range: Whether to use a more efficient narrow range for - weights quantization. 
With weight_narrow_range true, the range is - [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1]. activation_bits: Number of bits to use for quantizing activations. ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update quantization intervals for quantizing activations (see here about EMA: @@ -62,345 +61,274 @@ def Quantize(graph, vars_collection: (Optional) Collection where to store the variables for quantization interval ends. is_training: (Optional) Whether quantizing training graph or eval graph. - quantize_folded_weights_use_ema: (Optional, default False) Whether to - quantize weights after batchnorm-folding with exponential average - quantization. Raises: ValueError: When quantization fails. """ - context = _QuantizeContext(graph, weight_bits, weight_narrow_range, - activation_bits, ema_decay, quant_delay, - vars_collection, is_training, - quantize_folded_weights_use_ema) - - graph_ops = graph.get_operations() - - # Filter out backprop and summary related operations, leave only interesting - # op types. - def _IsInterestingOpWithWeights(op): - return (op.type in _QUANTIZABLE_TYPES and - not op.name.startswith(common.SKIPPED_PREFIXES)) - - for op in (op for op in graph_ops if _IsInterestingOpWithWeights(op)): - if op.name.endswith('/depthwise'): - # Separable convolution may consist of 2 convolution nodes. If so, skip - # .../depthwise and only quantize the top one. - separable_conv = context.GetOperationByNameDontThrow( - op.name[:-len('/depthwise')]) - if separable_conv and separable_conv.type == 'Conv2D': - continue - # Quantize add ops that come after Conv2D or DepthwiseConv2dNative. 
- if op.type in ['Conv2D', 'DepthwiseConv2dNative']: - add_context_re = re.search(r'^(.*)/[^/]+/', op.name) - if add_context_re is not None: - context.add_contexts.add(add_context_re.group(1)) - if not op.name.endswith('_Fold'): - folded_op = context.GetOperationByNameDontThrow(op.name + '_Fold') - # Do nothing if found, it will be quantized when it is iterated over. - if not folded_op: - context.QuantizeOpWithWeights(op, folded=False) - else: - context.QuantizeOpWithWeights(op, folded=True) - - context.QuantizeAddContexts() - - # Once all quantization ops have been inserted in the graph, collect update - # ops for their variables and modify the TF Slim update barrier (see - # https://www.tensorflow.org/code/tensorflow/contrib/slim/python/slim/learning.py) - # to depend on them. - try: - update_barrier = graph.get_operation_by_name('update_barrier') - except KeyError: - # In evaluation graph, this barrier may not exist. - return None - update_quant_ops = graph.get_collection_ref(_UPDATE_QUANT_OPS) - graph_editor.add_control_inputs(update_barrier, update_quant_ops) - - -class _QuantizeContext(object): - """Context holds references needed for quantization.""" - - def __init__(self, - graph, - weight_bits, - weight_narrow_range, - activation_bits, - ema_decay=0.999, - quant_delay=None, - vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - is_training=True, - quantize_folded_weights_use_ema=False): - """Initializes context to hold references needed for quantization. - - Args: - graph: Graph to modify. - weight_bits: Number of bits to use for quantizing weights. - weight_narrow_range: Whether to use a more efficient narrow range for - weights quantization. With weight_narrow_range true, the range is - [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1]. - activation_bits: Number of bits to use for quantizing activations. - ema_decay: (Optional) Float, EMA decay parameter. 
- quant_delay: (Optional, default None) Int, count of global steps for which - to delay quantization. This helps weights stabilize at the start of - training. - vars_collection: (Optional) Collection where to store the variables for - quantization interval ends. - is_training: (Optional) Whether quantizing training or eval graph. - quantize_folded_weights_use_ema: (Optional, default False) Whether to - quantize weights after batchnorm-folding with exponential average - quantization. - """ - self.graph = graph - self.weight_bits = weight_bits - self.weight_narrow_range = weight_narrow_range - self.activation_bits = activation_bits - self.ema_decay = ema_decay - self.quant_delay = quant_delay - self.vars_collection = vars_collection - self.is_training = is_training - self.quantize_folded_weights_use_ema = quantize_folded_weights_use_ema - self.input_to_ops_map = input_to_ops.InputToOps(graph) - self.add_contexts = set() - - def QuantizeAddContexts(self): - """Quantizes all add ops in self.add_contexts.""" - # Loop through sorted self.add_contexts so that op creation is - # deterministic. This is needed when using multiple worker replicas so that - # the ops can be initialized consistently. - for add_context in sorted(self.add_contexts): - add_op = self.GetOperationByNamesDontThrow([ - add_context + '/Add', add_context + '/add']) - if add_op is not None: - self._InsertQuantOp( - add_context, - add_op, - self.input_to_ops_map.ConsumerOperations(add_op), - name='add_quant', - moving_avg=True, - bits=self.activation_bits, - narrow_range=False) - - def QuantizeOpWithWeights(self, op, folded): - """Quantizes around the specific operation with or without batch norm. - - Args: - op: Operation to quantize. - folded: Operation has been folded and needs special handling if True. - Raises: - ValueError: When quantization fails. - """ - # Op name component before the last slash will be used as context. 
- context = re.search(r'^(.*)/([^/]+)', op.name).group(1) - - # Quantize weights. - if folded: - producer_op = self.graph.get_operation_by_name(context + '/mul_fold') - else: - try: - input_idx = next(i for i, v in enumerate(op.inputs) - if '/weights/' in v.name or - '/depthwise_weights' in v.name) - except StopIteration: - raise ValueError('No inputs to quantize for op: %s' % op) - producer_op = op.inputs[input_idx].op - - # If batch norm is used, the folded weights depend on the batch std, hence - # it is sensible to use EMA during training to smooth out the noise. This is - # controlled by the flag quantize_folded_weights_use_ema. Its default is - # False for backward compatibility. - # If there is no batch norm, weights do not depend on the batch and using - # the latest value of min and max is more efficient. - weight_use_ema = folded and self.quantize_folded_weights_use_ema - self._InsertQuantOp( + input_to_ops_map = input_to_ops.InputToOps(graph) + for layer_match in _FindLayersToQuantize(graph): + # Quantize the weights. + context = _GetContextFromOp(layer_match.layer_op) + _InsertQuantOp( context, - producer_op, [op], + layer_match.weight_tensor.op, [layer_match.layer_op], name='weights_quant', - moving_avg=weight_use_ema, - delay_requested=weight_use_ema, - bits=self.weight_bits, - narrow_range=self.weight_narrow_range) - - # Important: do not quantize biases here. During inference they are - # quantized to 32 bits, which is much finer than 8 bit quantization and - # depends on weight and input activation ranges. - - # Find activation and (optionally) Add operations to quantize. - activation_op, add_op, add_context = self._GetReluAndAddOperations(context, - op) - if add_op: - original_context = context - context = add_context - - # Quantize activation outputs. 
- consumer_ops = self.input_to_ops_map.ConsumerOperations(activation_op) - self._InsertQuantOp( - context, - activation_op, + moving_avg=False, + bits=weight_bits, + ema_decay=ema_decay, + quant_delay=quant_delay, + is_training=is_training, + narrow_range=True, + vars_collection=vars_collection) + + # Quantize the activations. + consumer_ops = input_to_ops_map.ConsumerOperations( + layer_match.activation_op) + add_context = context + if layer_match.bypass_op: + add_context = re.search(r'^(.*)/([^/]+)', context).group(1) + _InsertQuantOp( + add_context, + layer_match.activation_op, consumer_ops, name='act_quant', moving_avg=True, init_min=0.0, - bits=self.activation_bits, - narrow_range=False) - - # When a bypass connection was found, also quantize Add op input. - if add_op: - def _QuantizeAddInput(add_input): - if folded: - return add_input.op.name.endswith('/add_fold') - else: - return add_input.op.name.startswith(original_context + '/') - - for add_input in add_op.inputs: - if _QuantizeAddInput(add_input): - self._InsertQuantOp( - original_context, - add_input.op, [add_op], - name='conv_quant', - moving_avg=True, - bits=self.activation_bits, - narrow_range=False) - - def _GetReluAndAddOperations(self, context, op): - """Looks up a Relu* and Add operations in given context. - - Args: - context: Context where to look for operations. - op: Operation to quantize. - - Returns: - A triplet (Operation, Operation, string), the first element is an end - point operation, the second is Add operation (optional), the third element - is string context where the Add operation was found (optional). - - Raises: - ValueError: When operations cannot be found. - """ - activation_op = common.GetEndpointActivationOp(self.graph, context) - if activation_op: - return activation_op, None, None - - if '/' in context: - # If no activation op is there, look for them one level up. 
- add_context = re.search(r'^(.*)/([^/]+)', context).group(1) - activation_op = common.GetEndpointActivationOp(self.graph, add_context) - if not activation_op: - # Still no Relu, can happen on the top layer, just find the next node up, - # make sure it is BiasAdd. - consumers = [c for outp in op.outputs for c in outp.consumers()] - if len(consumers) != 1 or consumers[0].type != 'BiasAdd': - raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type)) - return consumers[0], None, None - if add_context: - add_op = self.GetOperationByNamesDontThrow([ - add_context + '/Add', add_context + '/add']) - return activation_op, add_op, add_context - else: - raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type)) - - def GetOperationByNameDontThrow(self, name): - """Returns an Operation with the given name. - - Args: - name: Name of Operation to return. - - Returns: - The Operation with the given name. None if the name does not correspond to - any operation in the graph. - """ - try: - return self.graph.get_operation_by_name(name) - except KeyError: - return None - - def GetOperationByNamesDontThrow(self, names): - """Returns an Operation with one of the given names. - - Args: - names: Names of Operation to return. - - Returns: - The Operation with one of the given names. None if none of the names - corresponds to any operation in the graph. - """ - for name in names: - op = self.GetOperationByNameDontThrow(name) - if op is not None: - return op - return None - - def _InsertQuantOp( - self, - context, - producer, - consumers, - name, - moving_avg=True, - init_min=-6.0, - init_max=6.0, - delay_requested=True, - bits=8, - narrow_range=False,): - """Inserts a quant op between a producer op and (multiple) consumer ops. - - Args: - context: Context where producer and consumer operations are nested. - producer: Producer operation of the pairs where quantization will be - inserted. - consumers: Consumer operations of the pairs. 
- name: Name for the new quantization op within the context. - moving_avg: Specifies whether to use exponential moving average or just - the last value seen. - init_min: Starting minimum value for the new quantization op. - init_max: Starting maximum value for the new quantization op. - delay_requested: If true, implement quantization delay where needed. - False value explicitly disables delay quantization everywhere. - bits: Number of bits to use for quantization, must be between 2 and 8. - narrow_range: Whether to use the narrow quantization range + ema_decay=ema_decay, + quant_delay=quant_delay, + bits=activation_bits, + vars_collection=vars_collection) + + # Quantize the inputs and output to the bypass (if it exists). The input to + # the bypass is the bias add, and the output is the activation. + if layer_match.bypass_op is not None: + _InsertQuantOp( + context, + layer_match.bias_add_op, [layer_match.bypass_op], + name='conv_quant', + moving_avg=True, + ema_decay=ema_decay, + quant_delay=quant_delay, + vars_collection=vars_collection, + bits=activation_bits) + _InsertQuantOp( + add_context, + layer_match.bypass_op, + input_to_ops_map.ConsumerOperations(layer_match.bypass_op), + name='add_quant', + moving_avg=True, + bits=activation_bits) + + +def _FindLayersToQuantize(graph): + """Matches layers in graph to quantize. + + Args: + graph: Graph to perform match on. + + Yields: + _LayerMatches. + """ + input_pattern = graph_matcher.OpTypePattern('*') + weight_var_pattern = graph_matcher.OpTypePattern('|'.join(_WEIGHT_TYPES)) + weight_pattern = graph_matcher.OpTypePattern( + 'Identity', inputs=[weight_var_pattern]) + + folded_weight_pattern = graph_matcher.OpTypePattern('Mul') + + # The weights inputs to the layer operation can either be from the Variable or + # the folded weight (Mul). 
+ layer_pattern = graph_matcher.OpTypePattern( + '|'.join(_QUANTIZABLE_TYPES), + inputs=[ + input_pattern, + graph_matcher.OneofPattern([weight_pattern, folded_weight_pattern]) + ]) + + folded_bias_mul_pattern = graph_matcher.OpTypePattern( + 'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_pattern]) + post_layer_op_correction_pattern = graph_matcher.OpTypePattern( + 'Add', inputs=[folded_bias_mul_pattern, + graph_matcher.OpTypePattern('*')]) + folded_bias_add_pattern = graph_matcher.OpTypePattern( + 'Add', + inputs=[ + post_layer_op_correction_pattern, + graph_matcher.OpTypePattern('*') + ]) + + bias_add_pattern = graph_matcher.OpTypePattern( + 'Add|BiasAdd', inputs=[layer_pattern, '*']) + + # The bias can come from the bias add or the folded bias add. + bypass_pattern_a = graph_matcher.OpTypePattern( + 'Add', + inputs=[ + graph_matcher.OneofPattern( + [bias_add_pattern, folded_bias_add_pattern]), '*' + ]) + bypass_pattern_b = graph_matcher.OpTypePattern( + 'Add', + inputs=[ + '*', + graph_matcher.OneofPattern( + [bias_add_pattern, folded_bias_add_pattern]) + ]) + + # The input to the activation can come from bias add, fold bias add or the + # bypasses. 
+ activation_pattern = graph_matcher.OpTypePattern( + '|'.join(_ACTIVATION_TYPES), + inputs=[ + graph_matcher.OneofPattern([ + bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a, + bypass_pattern_b + ]) + ]) + + layer_matcher = graph_matcher.GraphMatcher(activation_pattern) + for match_result in layer_matcher.match_graph(graph): + layer_op = match_result.get_op(layer_pattern) + weight_tensor = match_result.get_tensor(weight_pattern) + if weight_tensor is None: + weight_tensor = match_result.get_tensor(folded_weight_pattern) + activation_op = match_result.get_op(activation_pattern) + bias_add_op = match_result.get_op(bias_add_pattern) + if bias_add_op is None: + bias_add_op = match_result.get_op(folded_bias_add_pattern) + bypass_op = match_result.get_op(bypass_pattern_a) + if bypass_op is None: + bypass_op = match_result.get_op(bypass_pattern_b) + yield _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op, + bias_add_op) + + +class _LayerMatch(object): + """Contains all information related to a matched Layer.""" + + def __init__(self, layer_op, weight_tensor, activation_op, bypass_op, + bias_add_op): + self._layer_op = layer_op + self._weight_tensor = weight_tensor + self._activation_op = activation_op + self._bypass_op = bypass_op + self._bias_add_op = bias_add_op + + @property + def layer_op(self): + return self._layer_op + + @property + def weight_tensor(self): + return self._weight_tensor + + @property + def activation_op(self): + return self._activation_op + + @property + def bypass_op(self): + return self._bypass_op + + @property + def bias_add_op(self): + return self._bias_add_op + + +def _InsertQuantOp(context, + producer, + consumers, + name, + moving_avg=True, + init_min=-6.0, + init_max=6.0, + bits=8, + ema_decay=0.999, + quant_delay=None, + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, + is_training=True, + narrow_range=False): + """Inserts a quant op between a producer op and (multiple) consumer ops. 
+ + Args: + context: Context w,here producer and consumer operations are nested. + producer: Producer operation of the pairs where quantization will be + inserted. + consumers: Consumer operations of the pairs. + name: Name for the new quantization op within the context. + moving_avg: Specifies whether to use exponential moving average or just + the last value seen. + init_min: Starting minimum value for the new quantization op. + init_max: Starting maximum value for the new quantization op. + bits: Number of bits to use for quantization, must be between 2 and 8. + ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update + quantization intervals for quantizing activations (see here about EMA: + https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average). + quant_delay: (Optional, default None) Int, count of global steps for which + to delay quantization. This helps weights stabilize at the start of + training. + vars_collection: (Optional) Collection where to store the variables for + quantization interval ends. + is_training: (Optional) Whether quantizing training graph or eval graph. + narrow_range: Whether to use the narrow quantization range [1; 2^bits - 1] or wide range [0; 2^bits - 1]. - Raises: - ValueError: When producer operation is not directly connected to the - consumer operation. 
- """ - scope = context + '/' + name - inputs = producer.outputs[0] - if moving_avg: - quant = (quant_ops.MovingAvgQuantize( - inputs, - init_min=init_min, - init_max=init_max, - ema_decay=self.ema_decay, - is_training=self.is_training, - num_bits=bits, - narrow_range=narrow_range, - updates_collection=_UPDATE_QUANT_OPS, - vars_collection=self.vars_collection, - scope=scope)) - else: - quant = (quant_ops.LastValueQuantize( - inputs, - init_min=init_min, - init_max=init_max, - is_training=self.is_training, - num_bits=bits, - narrow_range=narrow_range, - updates_collection=_UPDATE_QUANT_OPS, - vars_collection=self.vars_collection, - scope=scope)) - - if delay_requested and self.quant_delay and self.quant_delay > 0: - activate_quant = math_ops.greater_equal( - training_util.get_or_create_global_step(), - self.quant_delay, - name=scope + '/activate_quant') - quant = control_flow_ops.cond( - activate_quant, - lambda: quant, - lambda: inputs, - name=scope + '/delayed_quant') - - nodes_modified_count = graph_editor.reroute_ts( - [quant], [inputs], can_modify=consumers) - if nodes_modified_count != len(consumers): - raise ValueError('Some inputs not quantized for ops: [%s]' % - ', '.join([consumer.name for consumer in consumers])) + Raises: + ValueError: When producer operation is not directly connected to the + consumer operation. 
+ """ + name_prefix = _AddContextToName(context, name) + inputs = producer.outputs[0] + if moving_avg: + quant = ( + quant_ops.MovingAvgQuantize( + inputs, + init_min=init_min, + init_max=init_max, + ema_decay=ema_decay, + is_training=is_training, + num_bits=bits, + narrow_range=narrow_range, + vars_collection=vars_collection, + name_prefix=name_prefix)) + else: + quant = ( + quant_ops.LastValueQuantize( + inputs, + init_min=init_min, + init_max=init_max, + is_training=is_training, + num_bits=bits, + narrow_range=narrow_range, + vars_collection=vars_collection, + name_prefix=name_prefix)) + + if quant_delay and quant_delay > 0: + activate_quant = math_ops.greater_equal( + training_util.get_or_create_global_step(), + quant_delay, + name=name_prefix + '/activate_quant') + quant = control_flow_ops.cond( + activate_quant, + lambda: quant, + lambda: inputs, + name=name_prefix + '/delayed_quant') + + nodes_modified_count = graph_editor.reroute_ts( + [quant], [inputs], can_modify=consumers) + if nodes_modified_count != len(consumers): + raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join( + [consumer.name for consumer in consumers])) + + +def _GetContextFromOp(op): + """Gets the root context name from the op name.""" + context_re = re.search(r'^(.*)/([^/]+)', op.name) + if context_re: + return context_re.group(1) + return '' + + +def _AddContextToName(context, name): + """Adds the context to the name if it exists.""" + if not context: + return name + return context + '/' + name diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index 57dab03f16..f1fe322049 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -101,7 +101,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): scope + '/weights_quant/AssignMaxLast', scope + '/weights/read' ] 
self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + '/Conv2D' + output_op_name = ( + scope + '/weights_quant/delayed_quant/Switch_1' + if delay else scope + '/Conv2D') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -176,7 +178,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): scope + '/weights_quant/AssignMaxLast', scope + '/weights/read' ] self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + '/MatMul' + output_op_name = ( + scope + '/weights_quant/delayed_quant/Switch_1' + if delay else scope + '/MatMul') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -252,7 +256,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): scope + '/depthwise_weights/read' ] self._AssertInputOpsAre(weights_quant, expected_inputs) - output_op_name = scope + '/depthwise' + output_op_name = ( + scope + '/weights_quant/delayed_quant/Switch_1' + if delay else scope + '/depthwise') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -316,40 +322,11 @@ class QuantizeTest(test_util.TensorFlowTestCase): for params in parameters_list: test_fn(params[0], params[1], params[2], params[3], params[4]) - def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, fused_batch_norm): - """Tests quantization: inputs -> Conv2d with batch norm -> Activation. - - Args: - activation: Callable that returns an Operation, a factory method for the - Activation. - activation_op_name: String, name of the Activation operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Activation. - delay: Int (optional), delay in number of steps until quantization starts. - fused_batch_norm: Bool, when true use FusedBatchNorm. 
- """ - self._testQuantize_Conv2dWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=True) - self._testQuantize_Conv2dWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=False) - def testQuantize_Conv2dWithBatchNorm(self): self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) - def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, fused_batch_norm, - use_ema): + def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, + with_bypass, delay, fused_batch_norm): """Tests quantization: inputs -> Conv2d with batch norm -> Activation. Args: @@ -360,7 +337,6 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. fused_batch_norm: Bool, when true use FusedBatchNorm. - use_ema: Bool, when true uses EMA quantization for BN folded weights. 
""" graph = ops.Graph() with graph.as_default(): @@ -394,23 +370,19 @@ class QuantizeTest(test_util.TensorFlowTestCase): fold_batch_norms.FoldBatchNorms(graph) - quantize.Quantize( - graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + quantize.Quantize(graph, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/' + ('AssignMinEma' - if use_ema else 'AssignMinLast'), - scope + '/weights_quant/' + ('AssignMaxEma' - if use_ema else 'AssignMaxLast'), - scope + '/mul_fold' + scope + '/weights_quant/' + 'AssignMinLast', + scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' ] self._AssertInputOpsAre(weights_quant, expected_inputs) output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if (delay and use_ema) else '/Conv2D_Fold') + if delay else '/Conv2D_Fold') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -438,40 +410,11 @@ class QuantizeTest(test_util.TensorFlowTestCase): if delay else 'control_dependency') self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) - def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, fused_batch_norm): - """Tests quantization: inputs -> FC with batch norm -> Activation. - - Args: - activation: Callable that returns an Operation, a factory method for the - Activation. - activation_op_name: String, name of the Activation operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Activation. - delay: Int (optional), delay in number of steps until quantization starts. - fused_batch_norm: Bool, when true use FusedBatchNorm. 
- """ - self._testQuantize_FCWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=True) - self._testQuantize_FCWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=False) - def testQuantize_FCWithBatchNorm(self): self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm) - def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name, - with_bypass, delay, fused_batch_norm, - use_ema): + def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, + with_bypass, delay, fused_batch_norm): """Tests quantization: inputs -> FC with batch norm -> Activation. Args: @@ -482,7 +425,6 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. fused_batch_norm: Bool, when true use FusedBatchNorm. - use_ema: Bool, when true uses EMA quantization for BN folded weights. 
""" graph = ops.Graph() with graph.as_default(): @@ -513,23 +455,19 @@ class QuantizeTest(test_util.TensorFlowTestCase): fold_batch_norms.FoldBatchNorms(graph) - quantize.Quantize( - graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + quantize.Quantize(graph, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/' + ('AssignMinEma' - if use_ema else 'AssignMinLast'), - scope + '/weights_quant/' + ('AssignMaxEma' - if use_ema else 'AssignMaxLast'), - scope + '/mul_fold' + scope + '/weights_quant/' + 'AssignMinLast', + scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' ] self._AssertInputOpsAre(weights_quant, expected_inputs) output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if delay and use_ema else '/MatMul_Fold') + if delay else '/MatMul_Fold') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: @@ -557,42 +495,13 @@ class QuantizeTest(test_util.TensorFlowTestCase): if delay else 'control_dependency') self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) - def _TestQuantize_DepthwiseConv2dWithBatchNorm( - self, activation, activation_op_name, with_bypass, delay, - fused_batch_norm): - """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. - - Args: - activation: Callable that returns an Operation, a factory method for the - Activation. - activation_op_name: String, name of the Activation operation. - with_bypass: Bool, when true there is an extra connection added from - inputs to just before Activation. - delay: Int (optional), delay in number of steps until quantization starts. - fused_batch_norm: Bool, when true use FusedBatchNorm. 
- """ - self._testQuantize_DepthwiseConv2dWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=True) - self._testQuantize_DepthwiseConv2dWithBatchNorm( - activation, - activation_op_name, - with_bypass, - delay, - fused_batch_norm, - use_ema=False) - def testQuantize_DepthwiseConv2dWithBatchNorm(self): self._RunBatchNormTestOverParameters( self._TestQuantize_DepthwiseConv2dWithBatchNorm) - def _testQuantize_DepthwiseConv2dWithBatchNorm( + def _TestQuantize_DepthwiseConv2dWithBatchNorm( self, activation, activation_op_name, with_bypass, delay, - fused_batch_norm, use_ema): + fused_batch_norm): """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. Args: @@ -603,7 +512,6 @@ class QuantizeTest(test_util.TensorFlowTestCase): inputs to just before Activation. delay: Int (optional), delay in number of steps until quantization starts. fused_batch_norm: Bool, when true use FusedBatchNorm. - use_ema: Bool, when true uses EMA quantization for BN folded weights. 
""" graph = ops.Graph() with graph.as_default(): @@ -637,22 +545,18 @@ class QuantizeTest(test_util.TensorFlowTestCase): fold_batch_norms.FoldBatchNorms(graph) - quantize.Quantize( - graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + quantize.Quantize(graph, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/' + ('AssignMinEma' - if use_ema else 'AssignMinLast'), - scope + '/weights_quant/' + ('AssignMaxEma' - if use_ema else 'AssignMaxLast'), - scope + '/mul_fold' + scope + '/weights_quant/' + 'AssignMinLast', + scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' ] self._AssertInputOpsAre(weights_quant, expected_inputs) output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' - if delay and use_ema else '/depthwise_Fold') + if delay else '/depthwise_Fold') self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) if with_bypass: diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 1e4dd7cf67..53cbd66741 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -45,13 +45,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): activation_fn=None, scope='test') relu = nn_ops.relu6(inputs) - context = quantize._QuantizeContext(graph=graph, weight_bits=8, - weight_narrow_range=True, - activation_bits=8) # Inserting a quantization op between two unconnected ops should fail with # ValueError. 
with self.assertRaises(ValueError) as err: - context._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp') + quantize._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp') self.assertEqual( str(err.exception), 'Some inputs not quantized for ops: [Relu6]') @@ -70,8 +67,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True, - activation_bits=8) + quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8) quantization_node_name = 'FakeQuantWithMinMaxVars' add_quant = graph.get_operation_by_name('test/add_quant/' + @@ -94,8 +90,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True, - activation_bits=8) + quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8) quantization_node_name = 'FakeQuantWithMinMaxVars' add_quant = graph.get_operation_by_name('test/add_quant/' + diff --git a/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py b/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py index 60a193db4c..468886da20 100644 --- a/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py +++ b/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import numpy as np -import unittest from tensorflow.contrib.reduce_slice_ops.python.ops import reduce_slice_ops from tensorflow.python.framework.test_util import TensorFlowTestCase diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 9b84635e85..0e62b315b6 100644 --- 
a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -39,8 +39,6 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test -from tensorflow.python.framework import test_util -from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell # pylint: enable=protected-access Linear = core_rnn_cell._Linear # pylint: disable=invalid-name @@ -167,9 +165,10 @@ class RNNCellTest(test.TestCase): m = array_ops.zeros([1, 2]) g, _ = contrib_rnn_cell.SRUCell(2)(x, m) sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g], {x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1]])}) + res = sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) # Smoke test self.assertAllClose(res[0], [[0.55255556, 0.55255556]]) diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 6af9db3f15..fe07493d0f 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -32,12 +32,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_impl # pylint: disable=unused-import from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables # pylint: disable=unused-import from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops import partitioned_variables -from tensorflow.python.ops import nn_impl from tensorflow.python.platform import tf_logging as logging from 
tensorflow.python.util import nest diff --git a/tensorflow/contrib/session_bundle/bundle_shim.py b/tensorflow/contrib/session_bundle/bundle_shim.py index 69db594f8a..1db97020a2 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim.py +++ b/tensorflow/contrib/session_bundle/bundle_shim.py @@ -134,9 +134,8 @@ def _convert_named_signatures_to_signature_def(signatures): signature_constants.PREDICT_OUTPUTS] # TODO(pdudnik): what if there are other signatures? Mimic cr/140900781 once # it is submitted. - if (input_signature.WhichOneof("type") != - legacy_constants.GENERIC_SIGNATURE or - output_signature.WhichOneof("type") != + if (input_signature.WhichOneof("type") != legacy_constants.GENERIC_SIGNATURE + or output_signature.WhichOneof("type") != legacy_constants.GENERIC_SIGNATURE): raise RuntimeError("Named input and output signatures can only be " "up-converted if they are generic signature. " diff --git a/tensorflow/contrib/session_bundle/gc.py b/tensorflow/contrib/session_bundle/gc.py index 249c23c88f..514cc0f652 100644 --- a/tensorflow/contrib/session_bundle/gc.py +++ b/tensorflow/contrib/session_bundle/gc.py @@ -70,7 +70,6 @@ import heapq import math import os -from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.platform import gfile from tensorflow.python.util.deprecation import deprecated diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index f5a9299d26..8a267ddac7 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -29,7 +29,6 @@ from tensorflow.contrib.framework.python.ops import variables as variables_lib from tensorflow.contrib.metrics.python.ops import metric_ops from tensorflow.contrib.slim.python.slim import evaluation from tensorflow.contrib.training.python.training import evaluation as evaluation_lib -from tensorflow.core.protobuf import saver_pb2 from 
tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.wrappers import hooks from tensorflow.python.framework import constant_op diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index 83f33806e0..6a200de1ea 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -738,7 +738,7 @@ def train(train_op, if summary_writer is not None: train_step_kwargs['summary_writer'] = sv.summary_writer - total_loss = 0 + total_loss = None should_retry = True while should_retry: try: @@ -771,10 +771,10 @@ def train(train_op, logging.info('Stopping Training.') sv.request_stop() break - except errors.OutOfRangeError: + except errors.OutOfRangeError as e: # OutOfRangeError is thrown when epoch limit per # tf.train.limit_epochs is reached. - logging.info('Caught OutOfRangeError. Stopping Training.') + logging.info('Caught OutOfRangeError. Stopping Training. %s', e) if logdir and sv.is_chief: logging.info('Finished training! 
Saving model to disk.') sv.saver.save(sess, sv.save_path, global_step=sv.global_step) diff --git a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py index 7b609ae96b..a1282847be 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py @@ -47,8 +47,8 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_): a_np = np.dot(a_np.T, a_np) # jacobi preconditioner jacobi_np = np.zeros_like(a_np) - jacobi_np[range(a_np.shape[0]), range(a_np.shape[1])] = (1.0 / - a_np.diagonal()) + jacobi_np[range(a_np.shape[0]), range(a_np.shape[1])] = ( + 1.0 / a_np.diagonal()) rhs_np = np.random.uniform( low=-1.0, high=1.0, size=shape_[0]).astype(dtype_) x_np = np.zeros_like(rhs_np) @@ -66,18 +66,30 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_): x = array_ops.placeholder(dtype_) jacobi = array_ops.placeholder(dtype_) operator = util.create_operator(a) - preconditioners = [None, util.identity_operator(a), - util.create_operator(jacobi)] + preconditioners = [ + None, util.identity_operator(a), + util.create_operator(jacobi) + ] cg_results = [] for preconditioner in preconditioners: cg_graph = linear_equations.conjugate_gradient( - operator, rhs, preconditioner=preconditioner, - x=x, tol=tol, max_iter=max_iter) + operator, + rhs, + preconditioner=preconditioner, + x=x, + tol=tol, + max_iter=max_iter) if use_static_shape_: cg_val = sess.run(cg_graph) else: - cg_val = sess.run(cg_graph, feed_dict={a: a_np, rhs: rhs_np, x: x_np, - jacobi: jacobi_np}) + cg_val = sess.run( + cg_graph, + feed_dict={ + a: a_np, + rhs: rhs_np, + x: x_np, + jacobi: jacobi_np + }) norm_r0 = np.linalg.norm(rhs_np) norm_r = np.linalg.norm(cg_val.r) self.assertLessEqual(norm_r, tol * norm_r0) diff --git a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py 
b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py index 12e94369cb..5d7534657b 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py @@ -85,9 +85,11 @@ class UtilTest(test.TestCase): op_shape_val, ax_val, aty_val = sess.run([op_shape, ax, aty]) else: op_shape_val, ax_val, aty_val = sess.run( - [op_shape, ax, aty], feed_dict={a: a_np, - x: x_np, - y: y_np}) + [op_shape, ax, aty], feed_dict={ + a: a_np, + x: x_np, + y: y_np + }) self.assertAllEqual(op_shape_val, [3, 2]) self.assertAllClose(ax_val, x_np) self.assertAllClose(aty_val, y_np) diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py index 4dfaa97ac9..d791d46763 100644 --- a/tensorflow/contrib/solvers/python/ops/linear_equations.py +++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import linalg_ops @@ -84,10 +85,9 @@ def conjugate_gradient(operator, cg_state = collections.namedtuple("CGState", ["i", "x", "r", "p", "gamma"]) def stopping_criterion(i, state): - return math_ops.logical_and(i < max_iter, - linalg_ops.norm(state.r) > tol) + return math_ops.logical_and(i < max_iter, linalg_ops.norm(state.r) > tol) - def cg_step(i, state): + def cg_step(i, state): # pylint: disable=missing-docstring z = operator.apply(state.p) alpha = state.gamma / util.dot(state.p, z) x = state.x + alpha * state.p @@ -108,8 +108,7 @@ def conjugate_gradient(operator, rhs = array_ops.expand_dims(rhs, -1) if x is None: x = array_ops.expand_dims( - array_ops.zeros( - n, dtype=rhs.dtype.base_dtype), -1) + array_ops.zeros(n, 
dtype=rhs.dtype.base_dtype), -1) r0 = rhs else: x = array_ops.expand_dims(x, -1) @@ -119,7 +118,7 @@ def conjugate_gradient(operator, else: p0 = preconditioner.apply(r0) gamma0 = util.dot(r0, p0) - tol = tol * linalg_ops.norm(r0) + tol *= linalg_ops.norm(r0) i = constant_op.constant(0, dtype=dtypes.int32) state = cg_state(i=i, x=x, r=r0, p=p0, gamma=gamma0) _, state = control_flow_ops.while_loop(stopping_criterion, cg_step, diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py index 73a5cf1e92..890ca20f4c 100644 --- a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py @@ -23,7 +23,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -from tensorflow.python.platform import resource_loader __all__ = ["sparsemax"] diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py index ba18f89e16..582d1e6136 100644 --- a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.util import loader -from tensorflow.python.platform import resource_loader from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/specs/BUILD b/tensorflow/contrib/specs/BUILD index 4b688690ae..084953a0a2 100644 --- a/tensorflow/contrib/specs/BUILD +++ b/tensorflow/contrib/specs/BUILD @@ -23,7 +23,6 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/ndlstm", "//tensorflow/python:array_ops", 
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:logging_ops", diff --git a/tensorflow/contrib/specs/README.md b/tensorflow/contrib/specs/README.md index b764e6e714..bcf34e601f 100644 --- a/tensorflow/contrib/specs/README.md +++ b/tensorflow/contrib/specs/README.md @@ -59,17 +59,6 @@ Reshaping: - `Squeeze` = tf.squeeze - `Expand` = tf.expand_dims -Multidimensional LSTM: - -These are intended as alternatives to 2D convolutions. For sequence models, -there will be other modeling primitives. - - - `Lstm2` = Fun(lstm2d.separable_lstm) # 2D-to-2D - - `Lstm2to1` = Fun(lstm2d.reduce_to_sequence) # 2D-to-1D - - `Lstm2to0` = Fun(lstm2d.reduce_to_final) # 2D-to-vector - - `Clstm2(n, m)` is a `Cl(n, [3,3])` followed by `Lstm2(m)` - - `Dws(n)` is a depthwise convolution `Cs(n, [1, 1])` - Other: - `Id` = identity diff --git a/tensorflow/contrib/specs/python/specs_ops.py b/tensorflow/contrib/specs/python/specs_ops.py index a6bd4d16c2..49b989b8d0 100644 --- a/tensorflow/contrib/specs/python/specs_ops.py +++ b/tensorflow/contrib/specs/python/specs_ops.py @@ -23,8 +23,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers -from tensorflow.contrib.ndlstm.python import lstm1d -from tensorflow.contrib.ndlstm.python import lstm2d from tensorflow.contrib.specs.python import specs_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops @@ -122,17 +120,6 @@ Sig = Fun(math_ops.sigmoid) Tanh = Fun(math_ops.tanh) Smax = Fun(nn_ops.softmax) -# 2D LSTM - -Lstm2 = Fun(lstm2d.separable_lstm) -Lstm2to1 = Fun(lstm2d.reduce_to_sequence) # 2D to 1D -Lstm2to0 = Fun(lstm2d.reduce_to_final) # 2D to depth-only - - -def Clstm2(n, *args, **kw): - """2D LSTM with 3x3 pre-convolution.""" - return Cl(n, [3, 3]) | Lstm2(*args, **kw) - def Dws(n): """Depth-wise convolution + sigmoid (used after LSTM).""" @@ -143,13 +130,6 @@ def Dwm(n): """Depth-wise convolution + 
softmax (used after LSTM).""" return Cm(n, [1, 1]) - -# 1D LSTM - -Lstm1 = Fun(lstm1d.ndlstm_base) -Lstm1to0 = Fun(lstm1d.sequence_to_final) # 1D to depth-only -Ssm = Fun(lstm1d.sequence_softmax) - # Sharing of Variables diff --git a/tensorflow/contrib/specs/python/specs_test.py b/tensorflow/contrib/specs/python/specs_test.py index 41782a9fc9..9a4ad36793 100644 --- a/tensorflow/contrib/specs/python/specs_test.py +++ b/tensorflow/contrib/specs/python/specs_test.py @@ -149,36 +149,6 @@ class SpecsTest(test.TestCase): self.assertEqual(tuple(result.shape), (10, 20)) self.assertEqual(summaries.tf_spec_structure(spec, inputs), "_ sig sig") - def testLstm2(self): - with self.test_session(): - inputs = constant_op.constant(_rand(1, 64, 64, 5)) - spec = "net = Lstm2(15)" - outputs = specs.create_net(spec, inputs) - self.assertEqual(outputs.get_shape().as_list(), [1, 64, 64, 15]) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (1, 64, 64, 15)) - - def testLstm2to1(self): - with self.test_session(): - inputs = constant_op.constant(_rand(1, 64, 64, 5)) - spec = "net = Lstm2to1(15)" - outputs = specs.create_net(spec, inputs) - self.assertEqual(outputs.get_shape().as_list(), [1, 64, 15]) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (1, 64, 15)) - - def testLstm2to0(self): - with self.test_session(): - inputs = constant_op.constant(_rand(1, 64, 64, 5)) - spec = "net = Lstm2to0(15)" - outputs = specs.create_net(spec, inputs) - self.assertEqual(outputs.get_shape().as_list(), [1, 15]) - variables.global_variables_initializer().run() - result = outputs.eval() - self.assertEqual(tuple(result.shape), (1, 15)) - def testKeywordRestriction(self): with self.test_session(): inputs = constant_op.constant(_rand(10, 20)) diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index 7d3b8b7437..2d6d7ea6a3 100644 --- 
a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -18,6 +18,42 @@ The operations in this package are safe to use with eager execution turned on or off. It has a more flexible API that allows summaries to be written directly from ops to places other than event log files, rather than propagating protos from @{tf.summary.merge_all} to @{tf.summary.FileWriter}. + +To use with eager execution enabled, write your code as follows: + +global_step = tf.train.get_or_create_global_step() +summary_writer = tf.contrib.summary.create_file_writer( + train_dir, flush_millis=10000) +with summary_writer.as_default(), tf.contrib.summary.always_record_summaries(): + # model code goes here + # and in it call + tf.contrib.summary.scalar("loss", my_loss) + # In this case every call to tf.contrib.summary.scalar will generate a record + # ... + +To use it with graph execution, write your code as follows: + +global_step = tf.train.get_or_create_global_step() +summary_writer = tf.contrib.summary.create_file_writer( + train_dir, flush_millis=10000) +with summary_writer.as_default(), tf.contrib.summary.always_record_summaries(): + # model definition code goes here + # and in it call + tf.contrib.summary.scalar("loss", my_loss) + # In this case every call to tf.contrib.summary.scalar will generate an op, + # note the need to run tf.contrib.summary.all_summary_ops() to make sure these + # ops get executed. + # ... + train_op = .... + +with tf.Session(...) as sess: + tf.global_variables_initializer().run() + tf.contrib.summary.initialize(graph=tf.get_default_graph()) + # ... + while not_done_training: + sess.run([train_op, tf.contrib.summary.all_summary_ops()]) + # ... 
+ """ from __future__ import absolute_import diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index a6968d8b2a..068ae35c71 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -154,10 +154,12 @@ def initialize( to @{tf.get_default_session}. Raises: - RuntimeError: If in eager mode, or if the current thread has no - default @{tf.contrib.summary.SummaryWriter}. + RuntimeError: If the current thread has no default + @{tf.contrib.summary.SummaryWriter}. ValueError: If session wasn't passed and no default session. """ + if context.in_eager_mode(): + return if context.context().summary_writer_resource is None: raise RuntimeError("No default tf.contrib.summary.SummaryWriter found") if session is None: @@ -292,13 +294,9 @@ def all_summary_ops(): Returns: The summary ops. - - Raises: - RuntimeError: If in Eager mode. """ if context.in_eager_mode(): - raise RuntimeError( - "tf.contrib.summary.all_summary_ops is only supported in graph mode.") + return None return ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access diff --git a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc index 28417b89e0..f8de8baa65 100644 --- a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc +++ b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc @@ -212,4 +212,20 @@ An op that shuts down a running distributed TPU system. The Op returns an error if no system is running. )doc"); -} // namespace tensorflow +REGISTER_OP("SessionStatus") + .Input("fetch_start_timestamp: double") + .Output("status: string") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Not for public usage. + +Returns messages from the current session as a serialized SessionStatusProto. + +This includes the current state of the compiler, along with any critical +logging or warning messages. 
+ +fetch_start_timestamp: any messages earlier than this will be excluded from the +returned proto. +)doc"); + +} // end namespace tensorflow diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index cb61984799..76f1dd2a56 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -47,20 +47,16 @@ setup( # 4 - Beta # 5 - Production/Stable 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', - 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Mathematics', 'Topic :: Scientific/Engineering :: Artificial Intelligence', @@ -69,4 +65,5 @@ setup( 'Topic :: Software Development :: Libraries :: Python Modules', ], license='Apache 2.0', - keywords='tensorflow performance tpu',) + keywords='tensorflow performance tpu', +) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index a236c08991..7d2f6556fb 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -517,7 +517,7 @@ class TPUEstimatorSpec( if self.eval_metrics is not None: host_calls['eval_metrics'] = self.eval_metrics if self.host_call is not None: - host_calls['host_call'] = wrap_hostcall_with_global_step(self.host_call) + host_calls['host_call'] = self.host_call host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls) eval_metric_ops = None if self.eval_metrics is not None: @@ -660,7 +660,7 @@ class 
TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): # for TPU computation waits for the infeed enqueue forever. Close the # Session to cancel the main thread Session.run execution. # - # However, sleep for 2 minutes before explicit closing to give some time + # We sleep for a few seconds before closing to give some time # for the TPU compilation error, if any, propagating, from TPU to CPU # host. Compilation errors should be reported by the main thread so that # the program can be interrupted and users can take action. Due to a race @@ -673,7 +673,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): # If the main session is still running, the infeed/outfeed errors are # legitimate, and should be logged. - if not self._finished: + if not self._finished and self._feed_error: logging.error('Feed error: %s', self._feed_error) logging.error('Closing session. A RuntimeError should follow.') session.close() @@ -731,10 +731,12 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): name='OutfeedController', target=self._run_outfeed, args=(session,)) def before_run(self, run_context): - if self._feed_error: - logging.warning('Feed error occurred, terminating session.') - run_context.request_stop() - return + self._feed_error = None + + # Wait for the cancellation timer to complete before continuing. 
+ if self._session_cancel_timer: + self._session_cancel_timer.join() + self._session_cancel_timer = None iterations = run_context.session.run(self._iterations_per_loop_var) @@ -1351,19 +1353,21 @@ class _ModelFnWrapper(object): self._call_model_fn(features, labels)) loss, train_op = estimator_spec.loss, estimator_spec.train_op - host_call_outfeed_ops = [] if isinstance(estimator_spec, TPUEstimatorSpec): captured_scaffold_fn.capture(estimator_spec.scaffold_fn) - if estimator_spec.host_call is not None: - host_call.record({ - 'host_call': wrap_hostcall_with_global_step( - estimator_spec.host_call)}) - host_call_outfeed_ops = host_call.create_enqueue_op() else: captured_scaffold_fn.capture(None) - with ops.control_dependencies([train_op] + host_call_outfeed_ops): - return array_ops.identity(loss) + # We must run train_op to update the variables prior to running the + # outfeed. + with ops.control_dependencies([train_op]): + host_call_outfeed_ops = [] + if (isinstance(estimator_spec, TPUEstimatorSpec) and + estimator_spec.host_call is not None): + host_call.record({'host_call': estimator_spec.host_call}) + host_call_outfeed_ops = host_call.create_enqueue_op() + with ops.control_dependencies(host_call_outfeed_ops): + return array_ops.identity(loss) return train_step, host_call, captured_scaffold_fn @@ -1708,38 +1712,6 @@ class _OutfeedHostCall(object): return ret -def wrap_hostcall_with_global_step(hostcall): - """Wrap the hostcall so that we update the global step upon every call.""" - if hostcall is None: - return None - host_fn, tensors = hostcall - - def global_step_host_fn(_global_step, *args, **kwargs): # pylint: disable=invalid-name - # Note that we don't have any ordering here, so the graph may see a - # global_step that's off by 1. - state_ops.assign( - training.get_global_step(), - math_ops.cast(_global_step[0], dtypes.int64)) - return host_fn(*args, **kwargs) - # Give the global step tensor a batch dimension. 
Reshape is not supported for - # int64, so we cast it to int32. - # TODO(jhseu): Remove the cast once int64 is supported. - global_step_tensor = array_ops.reshape( - math_ops.cast(training.get_global_step(), dtypes.int32), [1]) - if isinstance(tensors, dict): - outfeed_tensors = {'_global_step': global_step_tensor} - outfeed_tensors.update(tensors) - return global_step_host_fn, outfeed_tensors - else: - fn_args = util.fn_args(host_fn) - if len(tensors) != len(fn_args): - raise RuntimeError( - 'In TPUEstimatorSpec.host_call, length of tensors {} does not match ' - 'method args of the function, which takes {}.'.format( - len(tensors), len(fn_args))) - return global_step_host_fn, [global_step_tensor] + list(tensors) - - class _OutfeedHostCallHook(session_run_hook.SessionRunHook): """Hook to run host calls when use_tpu=False.""" diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a495770135..d0c9a72af9 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -193,6 +193,7 @@ CORE_PROTO_SRCS = [ "protobuf/rewriter_config.proto", "protobuf/tensor_bundle.proto", "protobuf/saver.proto", + "util/event.proto", "util/memmapped_file_system.proto", "util/saved_tensor_slice.proto", ] @@ -211,7 +212,6 @@ ADDITIONAL_CORE_PROTO_SRCS = [ "protobuf/named_tensor.proto", "protobuf/saved_model.proto", "protobuf/tensorflow_server.proto", - "util/event.proto", "util/test_log.proto", ] @@ -377,6 +377,17 @@ cc_library( ) cc_library( + name = "session_message", + srcs = ["util/session_message.cc"], + hdrs = ["util/session_message.h"], + deps = [ + ":framework", + ":lib", + ":protos_all_cc", + ], +) + +cc_library( name = "stacktrace_handler", srcs = ["platform/stacktrace_handler.cc"], hdrs = ["platform/stacktrace_handler.h"], @@ -454,6 +465,7 @@ tf_cuda_library( "framework/reader_interface.h", "framework/reader_op_kernel.h", "framework/register_types.h", + "framework/register_types_traits.h", "framework/resource_mgr.h", "framework/resource_op_kernel.h", 
"framework/selective_registration.h", @@ -786,6 +798,7 @@ tf_cuda_library( "graph/graph.h", "graph/graph_constructor.h", "graph/graph_def_builder.h", + "graph/graph_def_builder_util.h", "graph/node_builder.h", "graph/validate.h", "graph/while_context.h", @@ -1722,6 +1735,9 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ "platform/variant_coding.h", "graph/edgeset.h", "graph/graph.h", + "graph/graph_def_builder.h", + "graph/node_builder.h", + "graph/tensor_id.h", ] + glob( [ "example/**/*.h", @@ -1739,6 +1755,7 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ "framework/reader_base.*", "util/memmapped_file_system.*", "util/memmapped_file_system_writer.*", + "util/session_message.*", "util/version_info.cc", ], ) + select({ @@ -1808,6 +1825,9 @@ tf_cuda_library( ] + [ "graph/edgeset.cc", "graph/graph.cc", + "graph/graph_def_builder.cc", + "graph/node_builder.cc", + "graph/tensor_id.cc", "graph/while_context.h", "graph/while_context.cc", ], @@ -1822,6 +1842,7 @@ tf_cuda_library( "framework/resource_handle.cc", "util/memmapped_file_system.*", "util/memmapped_file_system_writer.*", + "util/session_message.cc", "util/version_info.cc", ], ) + select({ @@ -1936,6 +1957,7 @@ GRAPH_HDRS = [ "graph/graph.h", "graph/graph_constructor.h", # NOTE(mrry): Don't include the .cc since it depends on common_runtime. "graph/graph_def_builder.h", + "graph/graph_def_builder_util.h", "graph/graph_partition.h", "graph/mkl_layout_pass.h", "graph/mkl_tfconversion_pass.h", @@ -1956,12 +1978,9 @@ tf_cuda_library( "graph/colors.cc", "graph/control_flow.cc", "graph/costmodel.cc", - "graph/graph_def_builder.cc", "graph/graph_partition.cc", - "graph/node_builder.cc", "graph/optimizer_cse.cc", "graph/subgraph.cc", - "graph/tensor_id.cc", "graph/validate.cc", ], hdrs = GRAPH_HDRS, @@ -1990,6 +2009,7 @@ tf_cuda_library( "common_runtime/shape_refiner.h", "framework/versions.h", "graph/graph_constructor.cc", # Depends on common_runtime. + "graph/graph_def_builder_util.cc", # Depends on common_runtime. 
"public/session.h", "public/session_options.h", "public/version.h", diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 02c9cd5313..098024d219 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -193,7 +194,7 @@ class PlacerTest : public ::testing::Test { // Builds the given graph, and (if successful) indexes the node // names for use in placement, and later lookup. Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) { - TF_RETURN_IF_ERROR(builder.ToGraph(out_graph)); + TF_RETURN_IF_ERROR(GraphDefBuilderToGraph(builder, out_graph)); nodes_by_name_.clear(); for (Node* node : out_graph->nodes()) { nodes_by_name_[node->name()] = node->id(); diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 9d4a1eb8a1..878a1398c9 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -1448,8 +1448,7 @@ Status MasterSession::DoPartialRun(CallOptions* opts, const auto count = run_state->count; pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE; - pss.collect_rpcs = - req.options().trace_level() == RunOptions::FULL_TRACE; + pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE; pss.report_tensor_allocations_upon_oom = req.options().report_tensor_allocations_upon_oom(); @@ -1612,8 +1611,7 @@ Status MasterSession::DoRunWithLocalExecution( TRACEPRINTF("stepid 
%llu", step_id); pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE; - pss.collect_rpcs = - req.options().trace_level() == RunOptions::FULL_TRACE; + pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE; pss.report_tensor_allocations_upon_oom = req.options().report_tensor_allocations_upon_oom(); // Build the cost model every 'build_cost_model_every' steps after skipping an diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h index 3954af8ad8..fbddbda9e6 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h @@ -41,7 +41,7 @@ class GrpcWorker : public Worker { StatusCallback done); virtual void LoggingAsync(const LoggingRequest* request, - LoggingResponse* response, StatusCallback done); + LoggingResponse* response, StatusCallback done); WorkerEnv* env(); diff --git a/tensorflow/core/kernels/data/dataset.cc b/tensorflow/core/framework/dataset.cc index d18cb16018..4145ef7bc9 100644 --- a/tensorflow/core/kernels/data/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/kernels/data/dataset.h" -#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/dataset.h" + #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/node_builder.h" @@ -265,10 +265,6 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx, MakeDataset(ctx, input, another_input, output); } -Allocator* IteratorContext::allocator(AllocatorAttributes attrs) { - return params_.lib->device()->GetAllocator(attrs); -} - const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH"; const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] = "_DATASET_GRAPH_OUTPUT_NODE"; diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 96566c285a..6ab23d92a4 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -274,7 +274,7 @@ class IteratorContext { std::shared_ptr<const FunctionLibraryDefinition> function_library = nullptr; // The Allocator to be used to allocate the output of an iterator. 
- Allocator* allocator = nullptr; + std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr; }; explicit IteratorContext(Params params) : params_(std::move(params)) {} @@ -301,7 +301,9 @@ class IteratorContext { void set_lib(FunctionLibraryRuntime* lib) { params_.lib = lib; } - Allocator* allocator(AllocatorAttributes attrs); + Allocator* allocator(AllocatorAttributes attrs) { + return params_.allocator_getter(attrs); + } private: Params params_; diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index 0e2a410429..c9e8dd2217 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -177,10 +177,10 @@ class UnaryVariantOpRegistry { Op op_type_; StringPiece device_, typename_; }; - //friend declaration for operator== + // friend declaration for operator== // needed for clang template <typename Op> - friend bool operator==(const FuncTuple<Op> &l, const FuncTuple<Op> &r); + friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r); struct TupleHash { template <typename Op> std::size_t operator()( @@ -208,7 +208,8 @@ class UnaryVariantOpRegistry { binary_op_fns; // Find or insert a string into a persistent string storage - // container; return the StringPiece pointing to the permanent string location. + // container; return the StringPiece pointing to the permanent string + // location. static StringPiece GetPersistentStringPiece(const string& str) { const auto string_storage = PersistentStringStorage(); auto found = string_storage->find(str); diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc index 0cdcdb6685..99ced0c0f5 100644 --- a/tensorflow/core/graph/algorithm_test.cc +++ b/tensorflow/core/graph/algorithm_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/graph/subgraph.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" @@ -81,7 +82,7 @@ TEST(AlgorithmTest, ReversePostOrder) { BinaryOp("TestMul", w2, {input, 1}, b.opts().WithName("t3")); Graph g(OpRegistry::Global()); - TF_ASSERT_OK(b.ToGraph(&g)); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); std::vector<Node*> order; // Test reverse post order: @@ -139,7 +140,7 @@ TEST(AlgorithmTest, ReversePostOrderStable) { BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t3")); Graph g(OpRegistry::Global()); - TF_ASSERT_OK(b.ToGraph(&g)); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); std::vector<Node*> order; // Test reverse post order generates expected ordering. diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc index 33d2021f38..7a58347bd1 100644 --- a/tensorflow/core/graph/graph_def_builder.cc +++ b/tensorflow/core/graph/graph_def_builder.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include <utility> -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" @@ -72,16 +71,6 @@ Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const { return status_; } -Status GraphDefBuilder::ToGraph(Graph* graph) const { - if (status_.ok()) { - GraphDef graph_def; - graph_.ToGraphDef(&graph_def); - GraphConstructorOptions opts; - TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, graph)); - } - return status_; -} - string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const { if (name_.empty()) return graph_->NewName(op); return name_; diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h index a2c0c4d553..776a74c6d8 100644 --- a/tensorflow/core/graph/graph_def_builder.h +++ b/tensorflow/core/graph/graph_def_builder.h @@ -161,14 +161,6 @@ class GraphDefBuilder { // successful, and if so fill *graph_def. Status ToGraphDef(GraphDef* graph_def) const; - // Like ToGraphDef(), but converts to a Graph (using the default - // GraphConstructorOptions). - // TODO(josh11b): Make this faster; right now it converts - // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds - // edges from the source and to the sink node, resolves back edges - // by name), and makes sure the resulting graph is valid. - Status ToGraph(Graph* graph) const; - // Adds the function and gradient definitions in `fdef_lib` to this graph's op // registry. Ignores duplicate functions, and returns a bad status if an // imported function differs from an existing function or op with the same diff --git a/tensorflow/core/graph/graph_def_builder_test.cc b/tensorflow/core/graph/graph_def_builder_test.cc index e928c81b45..be3c2be800 100644 --- a/tensorflow/core/graph/graph_def_builder_test.cc +++ b/tensorflow/core/graph/graph_def_builder_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -34,7 +35,7 @@ TEST(GraphDefBuilderTest, Version) { // Check version when we convert to a Graph Graph graph(OpRegistry::Global()); - TF_EXPECT_OK(builder.ToGraph(&graph)); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, &graph)); ASSERT_EQ(graph.versions().producer(), TF_GRAPH_DEF_VERSION); ASSERT_EQ(graph.versions().min_consumer(), TF_GRAPH_DEF_VERSION_MIN_CONSUMER); diff --git a/tensorflow/core/graph/graph_def_builder_util.cc b/tensorflow/core/graph/graph_def_builder_util.cc new file mode 100644 index 0000000000..102c72185f --- /dev/null +++ b/tensorflow/core/graph/graph_def_builder_util.cc @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/graph/graph_def_builder_util.h" + +#include "tensorflow/core/graph/graph_constructor.h" + +namespace tensorflow { + +Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph) { + GraphDef graph_def; + TF_RETURN_IF_ERROR(builder.ToGraphDef(&graph_def)); + GraphConstructorOptions opts; + return ConvertGraphDefToGraph(opts, graph_def, graph); +} + +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph_def_builder_util.h b/tensorflow/core/graph/graph_def_builder_util.h new file mode 100644 index 0000000000..4a157e5b71 --- /dev/null +++ b/tensorflow/core/graph/graph_def_builder_util.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_UTIL_H_ +#define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_UTIL_H_ + +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class Graph; + +// Converts the `GraphDef` being built by `builder` to a `Graph` and +// stores it in `*graph`. +// TODO(josh11b): Make this faster; right now it converts +// Graph->GraphDef->Graph. This cleans up the graph (e.g. 
adds +// edges from the source and to the sink node, resolves back edges +// by name), and makes sure the resulting graph is valid. +Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_UTIL_H_ diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 0e8a1cb26c..7d3be15299 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -2211,7 +2211,7 @@ Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { return Status::OK(); } -#else // INTEL_MKL_ML +#else // INTEL_MKL_ML // This pass implements rewriting of graph to support following scenarios: // (A) Merging nodes in the graph @@ -2452,8 +2452,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // NOTE: names are alphabetically sorted. rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN, AddNRewrite}); - rinfo_.push_back({csinfo_.add, - mkl_op_registry::GetMklOpName(csinfo_.add), + rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), CopyAttrsDataType, AlwaysRewrite}); rinfo_.push_back({csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool), @@ -2509,8 +2508,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul), CopyAttrsDataType, AlwaysRewrite}); - rinfo_.push_back({csinfo_.relu, - mkl_op_registry::GetMklOpName(csinfo_.relu), + rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), CopyAttrsDataType, AlwaysRewrite}); rinfo_.push_back({csinfo_.relu_grad, mkl_op_registry::GetMklOpName(csinfo_.relu_grad), @@ -2537,7 +2535,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { mkl_op_registry::GetMklOpName(csinfo_.sub), CopyAttrsDataType, AlwaysRewrite}); - // Add info about which ops to add workspace 
edge to and the slots. wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3}); diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc index fde1ea1743..7219d9812f 100644 --- a/tensorflow/core/graph/subgraph_test.cc +++ b/tensorflow/core/graph/subgraph_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -361,7 +362,7 @@ static void BM_SubgraphHelper(int iters, int num_nodes, last_node = ops::SourceOp("In", b.opts().WithName(name)); } } - TF_CHECK_OK(b.ToGraph(&g)); + TF_CHECK_OK(GraphDefBuilderToGraph(b, &g)); } std::vector<string> fed; diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 2ac31ebf6a..3432de9dcd 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -140,6 +140,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", ], ) diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index db64e53026..edb0db65e9 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -52,21 +52,14 @@ bool RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) { return removed_input; } -// Remove duplicate control inputs. 
-void PruneControlInputs(NodeDef* node) { - std::unordered_set<string> inputs; - int pos = 0; - while (pos < node->input_size()) { - const string& input = node->input(pos); - if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) { - VLOG(1) << "**** Removing duplicate control input: " << input - << " from node " << node->DebugString(); - node->mutable_input()->SwapElements(pos, node->input_size() - 1); - node->mutable_input()->RemoveLast(); - } else { - ++pos; - } +void DeleteNodes(const std::set<int>& nodes_to_delete, GraphDef* graph) { + int last = graph->node_size() - 1; + for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) { + const int index = *it; + graph->mutable_node()->SwapElements(index, last); + last--; } + graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size()); } } // namespace @@ -75,6 +68,7 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) { if (!IsIdentity(node)) { return true; } + if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { return false; } @@ -397,22 +391,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx, void DependencyOptimizer::CleanControlInputs() { for (int i = 0; i < optimized_graph_->node_size(); ++i) { - PruneControlInputs(optimized_graph_->mutable_node(i)); - } -} - -void DependencyOptimizer::DeleteNodes(const std::set<int>& nodes_to_delete) { - int last = optimized_graph_->node_size() - 1; - for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) { - const int index = *it; - optimized_graph_->mutable_node()->SwapElements(index, last); - last--; + DedupControlInputs(optimized_graph_->mutable_node(i)); } - optimized_graph_->mutable_node()->DeleteSubrange(last + 1, - nodes_to_delete.size()); - // Rebuild the NodeMap which was invalidated by the node swapping above. 
- node_map_.reset(new NodeMap(optimized_graph_)); - BuildNodeToIdx(); } Status DependencyOptimizer::OptimizeDependencies() { @@ -437,7 +417,9 @@ Status DependencyOptimizer::OptimizeDependencies() { if (fetch_nodes_known_) { VLOG(1) << "Deleted " << nodes_to_delete.size() << " out of " << optimized_graph_->node_size() << " nodes."; - DeleteNodes(nodes_to_delete); + DeleteNodes(nodes_to_delete, optimized_graph_); + node_map_.reset(new NodeMap(optimized_graph_)); + BuildNodeToIdx(); } return Status::OK(); } @@ -576,7 +558,6 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, Status topo_sort_status; // Perform topological sort to prepare the graph for transitive reduction. topo_sort_status = TopologicalSort(optimized_graph_); - // Set up index-based graph datastructures to speed up analysis steps below. node_map_.reset(new NodeMap(optimized_graph_)); BuildNodeToIdx(); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h index 0f47528a04..61ed154793 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h @@ -52,8 +52,6 @@ class DependencyOptimizer : public GraphOptimizer { void CleanControlInputs(); // Builds a map from the &optimized_graph_->node(i) to i. void BuildNodeToIdx(); - // Removes the given set of nodes from the graph. - void DeleteNodes(const std::set<int>& nodes_to_delete); // Tries to optimize the node with the given index, possibly additional // optimizations by inserting nodes in nodes_to_simplify, and pruning nodes by // inserting them in nodes_to_delete. 
diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.cc b/tensorflow/core/grappler/optimizers/graph_rewriter.cc index 2d47ded156..b45ceb12a7 100644 --- a/tensorflow/core/grappler/optimizers/graph_rewriter.cc +++ b/tensorflow/core/grappler/optimizers/graph_rewriter.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" namespace tensorflow { @@ -61,10 +62,19 @@ void GraphRewriter::ForwardInputs( const NodeDef& original_node, const std::unordered_set<const NodeDef*>& nodes_to_delete, NodeDef* new_node) { - ForwardInputsInternal(original_node, nodes_to_delete, new_node); + ForwardInputsInternal(original_node, nodes_to_delete, false, new_node); if (!new_node->name().empty()) { optimized_nodes_[new_node->name()] = new_node; } + // Reorder inputs such that control inputs come after regular inputs. 
+ int pos = 0; + for (int i = 0; i < new_node->input_size(); ++i) { + if (!IsControlInput(new_node->input(i))) { + new_node->mutable_input()->SwapElements(pos, i); + ++pos; + } + } + DedupControlInputs(new_node); } bool GraphRewriter::DrivesControlDependency(const NodeDef& node) const { @@ -72,6 +82,10 @@ bool GraphRewriter::DrivesControlDependency(const NodeDef& node) const { control_dependency_drivers_.end(); } +bool GraphRewriter::FeedsMerge(const NodeDef& node) const { + return merge_feeders_.find(&node) != merge_feeders_.end(); +} + bool GraphRewriter::IsDrivenByControlDependency(const NodeDef& node) const { for (const auto& input : node.input()) { CHECK(!input.empty()); @@ -94,12 +108,27 @@ bool GraphRewriter::ReceivesRefValue(const NodeDef& node) const { return ref_receivers_.find(&node) != ref_receivers_.end(); } +bool GraphRewriter::IsDrivenBySwitch(const NodeDef& node) const { + return switch_receivers_.find(&node) != switch_receivers_.end(); +} + +bool GraphRewriter::RemovalIncreasesEdgeCount(const NodeDef& node) const { + const int in_degree = node.input_size(); + auto itr = nodes_.find(node.name()); + if (itr == nodes_.end()) { + return true; + } + const int out_degree = itr->second->out_degree; + return in_degree * out_degree > in_degree + out_degree; +} + void GraphRewriter::RecordConnectivity( const NodeDef& node, const std::unordered_set<string>& function_names) { const bool is_function = function_names.find(node.op()) != function_names.end(); bool ref_receiver = false; + bool switch_receiver = false; for (const auto& input : node.input()) { int position = 0; string input_node_name = ParseNodeName(input, &position); @@ -107,8 +136,14 @@ void GraphRewriter::RecordConnectivity( if (itr == nodes_.end()) { continue; } - const NodeInfo* fanin_info = itr->second.get(); + + NodeInfo* fanin_info = itr->second.get(); const NodeDef* fanin = fanin_info->def; + if (IsMerge(node)) { + merge_feeders_.insert(fanin); + } + // Update out_degree of fanin. 
+ ++fanin_info->out_degree; if (position < 0) { // This is a control edge control_dependency_drivers_.insert(fanin); @@ -120,7 +155,9 @@ void GraphRewriter::RecordConnectivity( if (is_function) { function_neighbors_.insert(fanin); } - + if (IsSwitch(*fanin)) { + switch_receiver = true; + } if (position < fanin_info->outputs.size() && IsRefType(fanin_info->outputs[position])) { ref_receiver = true; @@ -134,34 +171,41 @@ void GraphRewriter::RecordConnectivity( if (ref_receiver) { ref_receivers_.insert(&node); } + if (switch_receiver) { + switch_receivers_.insert(&node); + } } void GraphRewriter::ForwardInputsInternal( const NodeDef& node, const std::unordered_set<const NodeDef*>& nodes_to_delete, - NodeDef* new_node) { + bool add_as_control, NodeDef* new_node) { // To speed things up, use the optimized version of the node if // available. auto itr = optimized_nodes_.find(node.name()); if (itr != optimized_nodes_.end()) { for (const string& input : itr->second->input()) { - *new_node->add_input() = input; + *new_node->add_input() = + add_as_control ? AsControlDependency(NodeName(input)) : input; } return; } for (const auto& input : node.input()) { - string input_node_name = NodeName(input); + const string input_node_name = NodeName(input); auto itr = nodes_.find(input_node_name); if (itr == nodes_.end()) { // Invalid input, preserve it as is. - *new_node->add_input() = input; + *new_node->add_input() = + add_as_control ? AsControlDependency(NodeName(input)) : input; continue; } const NodeDef* input_node = itr->second->def; if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) { - ForwardInputsInternal(*input_node, nodes_to_delete, new_node); + ForwardInputsInternal(*input_node, nodes_to_delete, + add_as_control || IsControlInput(input), new_node); } else { - *new_node->add_input() = input; + *new_node->add_input() = + add_as_control ? 
AsControlDependency(NodeName(input)) : input; } } } diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.h b/tensorflow/core/grappler/optimizers/graph_rewriter.h index 4b9c9feef8..3d48d628e2 100644 --- a/tensorflow/core/grappler/optimizers/graph_rewriter.h +++ b/tensorflow/core/grappler/optimizers/graph_rewriter.h @@ -58,15 +58,27 @@ class GraphRewriter { // Returns true if the node has input from a stateful op. bool ReceivesRefValue(const NodeDef& node) const; + // Returns true if the node is driven by a Switch node. + bool IsDrivenBySwitch(const NodeDef& node) const; + + // Returns true if the node feeds a Merge node. + bool FeedsMerge(const NodeDef& node) const; + + // Returns true if removal of this degree would increase edge count, i.e. if + // in-degree * out-degree > in-degree + out-degree or if the condition could + // not be verified. + bool RemovalIncreasesEdgeCount(const NodeDef& node) const; + private: void RecordConnectivity(const NodeDef& node, const std::unordered_set<string>& function_names); void ForwardInputsInternal( const NodeDef& original_node, const std::unordered_set<const NodeDef*>& nodes_to_delete, - NodeDef* new_node); + bool add_as_control, NodeDef* new_node); struct NodeInfo { + int out_degree = 0; const NodeDef* def; // These are filled in when the NodeInfo is built, but not that they @@ -80,6 +92,8 @@ class GraphRewriter { std::unordered_set<const NodeDef*> function_neighbors_; std::unordered_set<const NodeDef*> cross_device_receivers_; std::unordered_set<const NodeDef*> ref_receivers_; + std::unordered_set<const NodeDef*> switch_receivers_; + std::unordered_set<const NodeDef*> merge_feeders_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index c9bec7890e..01282401a3 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -26,10 +26,17 @@ limitations 
under the License. namespace tensorflow { namespace grappler { -bool IsTrivialOp(const NodeDef& node) { +bool IsTrivialOp(const NodeDef& node, const GraphRewriter& rewriter) { // Remove the stop gradient nodes since they serve no purpose once the graph // is built. Also remove Identity ops. - if (IsStopGradient(node) || IsIdentity(node)) { + if (IsStopGradient(node)) { + return true; + } + if (IsIdentity(node) && + !(rewriter.FeedsMerge(node) && + rewriter.IsDrivenByControlDependency(node)) && + !(rewriter.IsDrivenBySwitch(node) && + rewriter.DrivesControlDependency(node))) { return true; } if (IsAddN(node) && NumNonControlInputs(node) <= 1) { @@ -41,7 +48,7 @@ bool IsTrivialOp(const NodeDef& node) { Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* pruned_graph) { - std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve(); + const std::unordered_set<string>& nodes_to_preserve = item.NodesToPreserve(); // Prune all the nodes that won't be executed, ie all the nodes that aren't in // the fanin of a fetch node. If fetch nodes aren't specified, we'll assume @@ -72,7 +79,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, // Check if we can further prune the graph, by removing the trivial ops. std::unordered_set<const NodeDef*> nodes_to_delete; for (auto& node : runnable_item.graph.node()) { - if (!IsTrivialOp(node)) { + if (!IsTrivialOp(node, rewriter)) { continue; } @@ -95,8 +102,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, // converting references to non-references. It is important to preserve // these non-references since the partitioner will avoid sending // non-references across partitions more than once. 
- if (!rewriter.DrivesControlDependency(node) && - !rewriter.IsDrivenByControlDependency(node) && + if (!rewriter.RemovalIncreasesEdgeCount(node) && !rewriter.IsConnectedToFunction(node) && !rewriter.IsDrivenByAnotherDevice(node) && !rewriter.ReceivesRefValue(node)) { @@ -112,13 +118,16 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, return Status::OK(); } + const bool fetches_are_known = !item.fetch.empty(); for (auto& node : runnable_item.graph.node()) { - NodeDef* new_node = pruned_graph->add_node(); - *new_node = node; - new_node->clear_input(); - rewriter.ForwardInputs(node, nodes_to_delete, new_node); + if (!fetches_are_known || + nodes_to_delete.find(&node) == nodes_to_delete.end()) { + NodeDef* new_node = pruned_graph->add_node(); + *new_node = node; + new_node->clear_input(); + rewriter.ForwardInputs(node, nodes_to_delete, new_node); + } } - VLOG(1) << "Pruned " << nodes_to_delete.size() << " nodes from the graph. The graph now contains " << pruned_graph->node_size() << " nodes."; diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc index ee722f311e..c39444299e 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc @@ -156,47 +156,42 @@ TEST_F(ModelPrunerTest, NoOpPruning) { const NodeDef& new_e = output.node(4); EXPECT_EQ(NodeName(e.name()), new_e.name()); - EXPECT_EQ(1, new_e.input_size()); - EXPECT_EQ(NodeName(d.name()), new_e.input(0)); - EXPECT_EQ(2, new_d.input_size()); - EXPECT_EQ(NodeName(b.name()), new_d.input(0)); - EXPECT_EQ(1, new_c.input_size()); - EXPECT_EQ(NodeName(b.name()), new_c.input(0)); + for (const auto& new_node : output.node()) { + if (new_node.name() != "a") { + EXPECT_EQ(1, new_node.input_size()); + EXPECT_EQ("a", new_node.input(0)); + } + } } -TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { - // Build a simple graph with a few trivially 
prunable ops. - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - - Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); - Output b = ops::Sqrt(s.WithOpName("b"), {a}); - Output c = ops::Identity(s.WithOpName("c"), b); - Output d = ops::Identity(s.WithOpName("d"), c); - Output e = ops::Sqrt(s.WithOpName("e").WithControlDependencies(c), {d}); +TEST_F(ModelPrunerTest, PreserveIdentities) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT); + ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL); + ops::Switch s(scope.WithOpName("switch"), v_in, v_ctrl); + // id0 is preserved because it is fed by a Switch and drives a + // control dependency. + Output id0 = ops::Identity(scope.WithOpName("id0"), s.output_true); + // id1 is preserved because it feeds a Merge. + Output id1 = ops::Identity( + scope.WithOpName("id1").WithControlDependencies(v_ctrl), s.output_false); + Output id2 = ops::Identity(scope.WithOpName("id2"), id0); + Output id3 = + ops::Identity(scope.WithOpName("id3").WithControlDependencies(id0), id1); + auto merge = ops::Merge(scope.WithOpName("merge"), {id0, id1}); GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + item.fetch.push_back("id2"); + item.fetch.push_back("id3"); + item.fetch.push_back("merge"); ModelPruner pruner; GraphDef output; Status status = pruner.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - EXPECT_EQ(5, output.node_size()); - const NodeDef& new_a = output.node(0); - EXPECT_EQ(NodeName(a.name()), new_a.name()); - const NodeDef& new_b = output.node(1); - EXPECT_EQ(NodeName(b.name()), new_b.name()); - const NodeDef& new_c = output.node(2); - EXPECT_EQ(NodeName(c.name()), new_c.name()); - const NodeDef& new_d = output.node(3); - EXPECT_EQ(NodeName(d.name()), new_d.name()); - const NodeDef& new_e = output.node(4); - EXPECT_EQ(NodeName(e.name()), new_e.name()); - - EXPECT_EQ(2, 
new_e.input_size()); - EXPECT_EQ(NodeName(c.name()), new_e.input(0)); - EXPECT_EQ("^c", new_e.input(1)); + TF_EXPECT_OK(status); + EXPECT_EQ(item.graph.node_size(), output.node_size()); } TEST_F(ModelPrunerTest, PruningSkipsRefOutputs) { @@ -239,54 +234,47 @@ TEST_F(ModelPrunerTest, PruningSkipsRefOutputs) { EXPECT_EQ("b", new_e.input(0)); } -TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) { +TEST_F(ModelPrunerTest, PruningForwardsCtrlDependencies) { // Build a simple graph with a few trivially prunable ops. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); Output b = ops::Sqrt(s.WithOpName("b"), {a}); Output c = ops::Sqrt(s.WithOpName("c"), {a}); - Output d = ops::Identity(s.WithOpName("d"), c); - Output e = ops::Identity(s.WithOpName("e"), d); - Output f = ops::Sqrt(s.WithOpName("f"), {e}); + Output d = ops::Identity(s.WithOpName("d").WithControlDependencies(b), c); + Output e = ops::Identity(s.WithOpName("e").WithControlDependencies(c), d); + Output f = ops::Sqrt(s.WithOpName("f"), {d}); + Output g = ops::Sqrt(s.WithOpName("g"), {e}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - - // Add a control dependency between b and d and another one between c and e. - // They should be properly forwarded. 
- EXPECT_EQ("d", item.graph.node(3).name()); - EXPECT_EQ("e", item.graph.node(4).name()); - *item.graph.mutable_node(3)->add_input() = "^b"; - *item.graph.mutable_node(4)->add_input() = "^c"; + item.fetch.push_back("f"); + item.fetch.push_back("g"); ModelPruner pruner; GraphDef output; Status status = pruner.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + LOG(INFO) << "After: " << output.DebugString(); - EXPECT_EQ(6, output.node_size()); - const NodeDef& new_a = output.node(0); - EXPECT_EQ(NodeName(a.name()), new_a.name()); - const NodeDef& new_b = output.node(1); - EXPECT_EQ(NodeName(b.name()), new_b.name()); - const NodeDef& new_c = output.node(2); - EXPECT_EQ(NodeName(c.name()), new_c.name()); - const NodeDef& new_d = output.node(3); - EXPECT_EQ(NodeName(d.name()), new_d.name()); - const NodeDef& new_e = output.node(4); - EXPECT_EQ(NodeName(e.name()), new_e.name()); - const NodeDef& new_f = output.node(5); - EXPECT_EQ(NodeName(f.name()), new_f.name()); - - EXPECT_EQ(1, new_f.input_size()); - EXPECT_EQ(NodeName(e.name()), new_f.input(0)); - EXPECT_EQ(2, new_e.input_size()); - EXPECT_EQ(NodeName(d.name()), new_e.input(0)); - EXPECT_EQ("^c", new_e.input(1)); - EXPECT_EQ(2, new_d.input_size()); - EXPECT_EQ(NodeName(c.name()), new_d.input(0)); - EXPECT_EQ("^b", new_d.input(1)); + EXPECT_EQ(5, output.node_size()); + for (const auto& new_node : output.node()) { + // "d" and "e" should be removed. + EXPECT_NE("d", new_node.name()); + EXPECT_NE("e", new_node.name()); + if (new_node.name() == "g") { + EXPECT_EQ(2, new_node.input_size()); + // The input from switch should be forwarded to id3. + EXPECT_EQ("c", new_node.input(0)); + EXPECT_EQ("^b", new_node.input(1)); + } + if (new_node.name() == "f") { + EXPECT_EQ(2, new_node.input_size()); + // The input from switch should be forwarded to id3. 
+ EXPECT_EQ("c", new_node.input(0)); + EXPECT_EQ("^b", new_node.input(1)); + } + } } TEST_F(ModelPrunerTest, PruningPerservesFetch) { @@ -296,6 +284,7 @@ TEST_F(ModelPrunerTest, PruningPerservesFetch) { Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); Output b = ops::Sqrt(s.WithOpName("b"), {a}); Output c = ops::Identity(s.WithOpName("c"), b); + Output d = ops::Identity(s.WithOpName("d"), c); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 634577ed30..7bfaf36865 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -312,6 +312,20 @@ void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation, } } +void DedupControlInputs(NodeDef* node) { + std::unordered_set<string> inputs; + int pos = 0; + while (pos < node->input_size()) { + const string& input = node->input(pos); + if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) { + node->mutable_input()->SwapElements(pos, node->input_size() - 1); + node->mutable_input()->RemoveLast(); + } else { + ++pos; + } + } +} + namespace { template <typename T> inline void STLSortAndRemoveDuplicates(T* v) { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 8840c44d05..4ecb28f681 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -143,6 +143,9 @@ int NumNonControlInputs(const NodeDef& node); // Number of connected non-control outputs. int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map); +// Removes redundant control inputs from node. +void DedupControlInputs(NodeDef* node); + // Returns the data type in attribute `attr_name` of `node`. If that attribute // doesn't exist, returns DT_INVALID. 
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name); diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index ba4e6b1bae..eabce5b5ee 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -29,83 +29,84 @@ namespace { class UtilsTest : public ::testing::Test { protected: NodeDef CreateConcatOffsetNode() const { - const string gdef_ascii = R"EOF( -name: "gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/ConcatOffset" -op: "ConcatOffset" -input: "InceptionV3/Mixed_7c/Branch_1/concat_v2/axis" -input: "gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/Shape" -input: "gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/Shape_1" -attr { - key: "N" - value { - i: 2 - } -} - )EOF"; + const string gdef_ascii = + " name: 'gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/" + "ConcatOffset'" + " op: 'ConcatOffset'" + " input: 'InceptionV3/Mixed_7c/Branch_1/concat_v2/axis'" + " input: 'gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/Shape'" + " input: " + " 'gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/Shape_1'" + " attr {" + " key: 'N'" + " value {" + " i: 2" + " }" + " }"; NodeDef node; CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &node)); return node; } NodeDef CreateDequeueNode() const { - const string gdef_ascii = R"EOF( -name: "Train/TrainInput/input_producer_Dequeue" -op: "QueueDequeueV2" -input: "Train/TrainInput/input_producer" -attr { - key: "component_types" - value { - list { - type: DT_INT32 - } - } -} -attr { - key: "timeout_ms" - value { - i: -1 - } -} - )EOF"; + const string gdef_ascii = + " name: 'Train/TrainInput/input_producer_Dequeue'" + " op: 'QueueDequeueV2'" + " input: 'Train/TrainInput/input_producer'" + " attr {" + " key: 'component_types'" + " value {" + " list {" + " type: DT_INT32" + " }" + " }" + " }" + " attr {" + " key: 'timeout_ms'" + " value {" + " i: -1" + " }" + " }"; + NodeDef node; 
CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &node)); return node; } NodeDef CreateFusedBatchNormNode() const { - const string gdef_ascii = R"EOF( -name: "InceptionV3/Conv2d_1a_3x3/BatchNorm/FusedBatchNorm" -op: "FusedBatchNorm" -input: "InceptionV3/Conv2d_1a_3x3/BatchNorm/FusedBatchNorm" -input: "InceptionV3/Conv2d_1a_3x3/BatchNorm/gamma/read" -input: "InceptionV3/Conv2d_1a_3x3/BatchNorm/beta/read" -input: "InceptionV3/Conv2d_1a_3x3/BatchNorm/Const" -input: "InceptionV3/Conv2d_1a_3x3/BatchNorm/Const_1" -attr { - key: "T" - value { - type: DT_FLOAT - } -} -attr { - key: "data_format" - value { - s: "NHWC" - } -} -attr { - key: "epsilon" - value { - f: 0.001 - } -} -attr { - key: "is_training" - value { - b: true - } -} - )EOF"; + const string gdef_ascii = + " name: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/FusedBatchNorm'" + " op: 'FusedBatchNorm'" + " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/FusedBatchNorm'" + " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/gamma/read'" + " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/beta/read'" + " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/Const'" + " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/Const_1'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " attr {" + " key: 'data_format'" + " value {" + " s: 'NHWC'" + " }" + " }" + " attr {" + " key: 'epsilon'" + " value {" + " f: 0.001" + " }" + " }" + " attr {" + " key: 'is_training'" + " value {" + " b: true" + " }" + " }"; + NodeDef node; CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &node)); return node; @@ -250,6 +251,49 @@ TEST_F(UtilsTest, GetTailOfChain) { EXPECT_EQ("noop", tail->name()); } +TEST_F(UtilsTest, DedupControlInputs) { + NodeDef foo; + foo.set_name("foo"); + foo.add_input("bar"); + DedupControlInputs(&foo); + EXPECT_EQ(1, foo.input_size()); + EXPECT_EQ("bar", foo.input(0)); + + foo.set_input(0, "^bar"); + DedupControlInputs(&foo); + EXPECT_EQ(1, foo.input_size()); + EXPECT_EQ("^bar", foo.input(0)); + + 
foo.set_input(0, "bar"); + foo.add_input("bar"); + DedupControlInputs(&foo); + EXPECT_EQ(2, foo.input_size()); + EXPECT_EQ("bar", foo.input(0)); + EXPECT_EQ("bar", foo.input(1)); + + foo.set_input(1, "^bar"); + DedupControlInputs(&foo); + EXPECT_EQ(1, foo.input_size()); + EXPECT_EQ("bar", foo.input(0)); + + foo.set_input(0, "^bar"); + foo.add_input("^bar"); + DedupControlInputs(&foo); + EXPECT_EQ(1, foo.input_size()); + EXPECT_EQ("^bar", foo.input(0)); + + foo.set_input(0, "bar"); + foo.add_input("gnu"); + foo.add_input("^bar"); + foo.add_input("^gnu"); + DedupControlInputs(&foo); + EXPECT_EQ(2, foo.input_size()); + EXPECT_EQ("bar", foo.input(0)); + EXPECT_EQ("gnu", foo.input(1)); +} + +TEST_F(UtilsTest, DeleteNodes) {} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/kernels/compare_and_bitpack_op.cc b/tensorflow/core/kernels/compare_and_bitpack_op.cc index 39e4f24ed5..224fe534e3 100644 --- a/tensorflow/core/kernels/compare_and_bitpack_op.cc +++ b/tensorflow/core/kernels/compare_and_bitpack_op.cc @@ -114,14 +114,13 @@ struct ComputeShard<T, for (int64 i = start; i < limit; ++i) { uint8* out = output.data() + i; const int64 block = *reinterpret_cast<const int64*>(input.data() + 8 * i); - *out = - ((((block & (1LL << (7 * 8))) >> (7 * 8 - 7))) | - (((block & (1LL << (6 * 8))) >> (6 * 8 - 6))) | - (((block & (1LL << (5 * 8))) >> (5 * 8 - 5))) | - (((block & (1LL << (4 * 8))) >> (4 * 8 - 4))) | - (((block & (1LL << (3 * 8))) >> (3 * 8 - 3))) | - (((block & (1LL << (2 * 8))) >> (2 * 8 - 2))) | - (((block & (1LL << 8)) >> (1 * 8 - 1))) | (((block & (1LL))))); + *out = ((((block & (1LL << (7 * 8))) >> (7 * 8 - 7))) | + (((block & (1LL << (6 * 8))) >> (6 * 8 - 6))) | + (((block & (1LL << (5 * 8))) >> (5 * 8 - 5))) | + (((block & (1LL << (4 * 8))) >> (4 * 8 - 4))) | + (((block & (1LL << (3 * 8))) >> (3 * 8 - 3))) | + (((block & (1LL << (2 * 8))) >> (2 * 8 - 2))) | + (((block & (1LL << 8)) >> (1 * 8 - 1))) | (((block & 
(1LL))))); } #else for (int64 i = start; i < limit; ++i) { diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index c4e21257ff..8e91baaa1c 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -44,9 +44,10 @@ tf_kernel_library( ], ) +# TODO(mrry): Remove this empty forwarding library. cc_library( name = "dataset", - srcs = ["dataset.cc"], + srcs = [], hdrs = ["dataset.h"], deps = [ "//tensorflow/core:core_cpu", diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index fc3e291afb..d7d4ad5cf7 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -160,6 +160,10 @@ class IteratorResource : public ResourceBase { params.runner = *(ctx->runner()); params.function_library = flib_def; params.lib = lib_; + DeviceBase* device = lib_->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; IteratorContext iter_ctx(std::move(params)); TF_RETURN_IF_ERROR(captured_iterator->Restore(&iter_ctx, reader)); @@ -605,6 +609,11 @@ class ToSingleElementOp : public AsyncOpKernel { params.env = ctx->env(); params.runner = *(ctx->runner()); params.lib = ctx->function_library(); + DeviceBase* device = ctx->function_library()->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; + IteratorContext iter_ctx(std::move(params)); std::vector<Tensor> components; @@ -863,6 +872,10 @@ class IteratorGetNextOp : public AsyncOpKernel { }; params.runner = *(ctx->runner()); params.function_library = iterator->function_library(); + DeviceBase* device = ctx->function_library()->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; IteratorContext iter_ctx(std::move(params)); OP_REQUIRES_OK_ASYNC( @@ -905,6 +918,10 @@ class 
IteratorGetNextSyncOp : public OpKernel { }; params.runner = *(ctx->runner()); params.function_library = iterator->function_library(); + DeviceBase* device = ctx->function_library()->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; IteratorContext iter_ctx(std::move(params)); OP_REQUIRES_OK(ctx, diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc index ef724f0a29..b539b00009 100644 --- a/tensorflow/core/kernels/mkl_aggregate_ops.cc +++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc @@ -318,9 +318,9 @@ class MklAddNOp : public OpKernel { // if the shapes of two tensors are not same raise op error TensorShape src1_shape, src2_shape; src1_shape = input1_in_mkl_format ? src1_mkl_shape.GetTfShape() - : src1_tensor.shape(); + : src1_tensor.shape(); src2_shape = input2_in_mkl_format ? src2_mkl_shape.GetTfShape() - : src2_tensor.shape(); + : src2_tensor.shape(); if (!src1_shape.IsSameSize(src2_shape)) { ctx->SetStatus(errors::InvalidArgument( diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc index cff1bd18a7..d545d34fdf 100644 --- a/tensorflow/core/kernels/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc @@ -428,11 +428,8 @@ class MklAvgPoolingGradOp : public OpKernel { TensorFormat data_format_; }; // MklAvgPoolingGradOp - - #else - template <typename Device, typename T> class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> { public: @@ -485,7 +482,7 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> { } const int kOutputIndex = 0; AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor, - output_tf_shape, output_mkl_shape); + output_tf_shape, output_mkl_shape); CHECK_NOTNULL(output_tensor); return; } @@ -702,12 +699,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> { } }; // MklAvgPoolingGradOp - - - #endif // INTEL_MKL_ML - 
REGISTER_KERNEL_BUILDER(Name("_MklAvgPool") .Device(DEVICE_CPU) .TypeConstraint<float>("T") diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index cbda12689f..2953426d58 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -41,8 +41,6 @@ limitations under the License. #include "tensorflow/core/util/mkl_util.h" - - #ifndef INTEL_MKL_ML #include "mkldnn.hpp" diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc index acb0db57b3..5a8799ae93 100644 --- a/tensorflow/core/kernels/mkl_input_conversion_op.cc +++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc @@ -296,9 +296,9 @@ class MklInputConversionOp : public OpKernel { if (tf_shapes_are_same) { auto input0_md = input_shape_0.GetMklLayout(); auto input1_md = input_shape_1.GetMklLayout(); - + // If both have the same shape and same format, pass them through - if ( input0_md.data.format == input1_md.data.format) { + if (input0_md.data.format == input1_md.data.format) { VLOG(1) << "MklInputConversionOp: No conversion needed, " << "copying MKL inputs with identical shapes to output"; @@ -306,9 +306,10 @@ class MklInputConversionOp : public OpKernel { ForwardMklTensorInToOut(context, 1, 1); return; } else { - VLOG(1) << "MklInputConversionOp: Shape is same, but format is different, " + VLOG(1) << "MklInputConversionOp: Shape is same, but format is " + "different, " << "need to convert to same format"; - + // Convert input0, and keep input1 unchanged // Create MklDnnShape for output mkl tensor based on input0 Tensor* tensor_out; @@ -324,7 +325,8 @@ class MklInputConversionOp : public OpKernel { // Create output Mkl tensor for index 0 AllocateOutputSetMklShape(context, 0, &tensor_out, - input_tensor_0.shape(), mkl_output_mkl_shape); + input_tensor_0.shape(), + mkl_output_mkl_shape); // Create MklDnnData object for input0 tesnsor auto cpu_engine = 
engine(engine::cpu, 0); @@ -333,15 +335,15 @@ class MklInputConversionOp : public OpKernel { // Create reorder from input0's layout to input1's layout std::vector<primitive> net; - CHECK_EQ(input.CheckReorderToOpMem(memory::primitive_desc( - input1_md, cpu_engine), - tensor_out, &net), - true); + CHECK_EQ(input.CheckReorderToOpMem( + memory::primitive_desc(input1_md, cpu_engine), + tensor_out, &net), + true); stream(stream::kind::eager).submit(net).wait(); // Input1 will be passed through ForwardMklTensorInToOut(context, 1, 1); - return; + return; } } diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 0be8355afa..51db3991e2 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -368,11 +368,8 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) { mkl_context.MklCleanup(); } - - #else // INTEL_MKL_ML - template <typename Device, typename T, algorithm alg_kind> class MklReluOpBase : public OpKernel { public: diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc index da9ab01e8d..62e814ff77 100644 --- a/tensorflow/core/kernels/unravel_index_op.cc +++ b/tensorflow/core/kernels/unravel_index_op.cc @@ -39,8 +39,9 @@ class UnravelIndexOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& indices_tensor = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices_tensor.shape()) || - TensorShapeUtils::IsScalar(indices_tensor.shape()), + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(indices_tensor.shape()) || + TensorShapeUtils::IsScalar(indices_tensor.shape()), errors::InvalidArgument( "The indices can only be scalar or vector, got \"", indices_tensor.shape().DebugString(), "\"")); @@ -88,10 +89,11 @@ class UnravelIndexOp : public OpKernel { output = output.constant(indices_tensor.scalar<Tidx>()()); output = output.binaryExpr(strides, mod_op<Tidx>()) / strides_shifted; } else { - 
OP_REQUIRES_OK(ctx, ctx->allocate_output( - 0, TensorShape({dims_tensor.NumElements(), - indices_tensor.NumElements()}), - &output_tensor)); + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, + TensorShape({dims_tensor.NumElements(), + indices_tensor.NumElements()}), + &output_tensor)); auto output = output_tensor->matrix<Tidx>(); diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 2580eaf987..8db4373fdc 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -46690,6 +46690,49 @@ op { } } op { + name: "Roll" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "shift" + type_attr: "Tshift" + } + input_arg { + name: "axis" + type_attr: "Taxis" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshift" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Taxis" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "Round" input_arg { name: "x" @@ -65318,6 +65361,34 @@ op { } } op { + name: "UnravelIndex" + input_arg { + name: "indices" + type_attr: "Tidx" + } + input_arg { + name: "dims" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "Tidx" + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "UnsortedSegmentMax" input_arg { name: "data" diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 8df126735b..2e96211fdc 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -22135,6 +22135,49 @@ op { } } op { + name: "Roll" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "shift" + type_attr: "Tshift" + } + input_arg { + name: "axis" + type_attr: 
"Taxis" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tshift" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Taxis" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "Round" input_arg { name: "x" @@ -30888,6 +30931,34 @@ op { } } op { + name: "UnravelIndex" + input_arg { + name: "indices" + type_attr: "Tidx" + } + input_arg { + name: "dims" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "Tidx" + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "UnsortedSegmentMax" input_arg { name: "data" diff --git a/tensorflow/core/platform/cpu_feature_guard.cc b/tensorflow/core/platform/cpu_feature_guard.cc index 7caf9d4db6..b570658158 100644 --- a/tensorflow/core/platform/cpu_feature_guard.cc +++ b/tensorflow/core/platform/cpu_feature_guard.cc @@ -106,7 +106,7 @@ void InfoAboutUnusedCPUFeatures() { CheckIfFeatureUnused(CPUFeature::AVX2, "AVX2", missing_instructions); #endif // __AVX2__ -#else // if defined(_MSC_VER) && !defined(__clang__) +#else // if defined(_MSC_VER) && !defined(__clang__) #ifndef __SSE__ CheckIfFeatureUnused(CPUFeature::SSE, "SSE", missing_instructions); diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc index 4862fd85be..301fcb9dbf 100644 --- a/tensorflow/core/platform/s3/s3_file_system.cc +++ b/tensorflow/core/platform/s3/s3_file_system.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include <aws/core/Aws.h> #include <aws/core/config/AWSProfileConfigLoader.h> #include <aws/core/utils/FileSystemUtils.h> +#include <aws/core/utils/StringUtils.h> #include <aws/core/utils/logging/AWSLogging.h> #include <aws/core/utils/logging/LogSystemInterface.h> #include <aws/core/utils/StringUtils.h> @@ -129,8 +130,7 @@ Aws::Client::ClientConfiguration& GetDefaultClientConfig() { return cfg; }; - -void ShutdownClient(Aws::S3::S3Client *s3_client) { +void ShutdownClient(Aws::S3::S3Client* s3_client) { if (s3_client != nullptr) { delete s3_client; Aws::SDKOptions options; @@ -167,7 +167,7 @@ Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket, class S3RandomAccessFile : public RandomAccessFile { public: S3RandomAccessFile(const string& bucket, const string& object, - std::shared_ptr<Aws::S3::S3Client> s3_client) + std::shared_ptr<Aws::S3::S3Client> s3_client) : bucket_(bucket), object_(object), s3_client_(s3_client) {} Status Read(uint64 offset, size_t n, StringPiece* result, @@ -203,7 +203,7 @@ class S3RandomAccessFile : public RandomAccessFile { class S3WritableFile : public WritableFile { public: S3WritableFile(const string& bucket, const string& object, - std::shared_ptr<Aws::S3::S3Client> s3_client) + std::shared_ptr<Aws::S3::S3Client> s3_client) : bucket_(bucket), object_(object), s3_client_(s3_client), @@ -285,8 +285,8 @@ class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { } // namespace -S3FileSystem::S3FileSystem() : - s3_client_(nullptr, ShutdownClient), client_lock_() {} +S3FileSystem::S3FileSystem() + : s3_client_(nullptr, ShutdownClient), client_lock_() {} S3FileSystem::~S3FileSystem() {} @@ -408,7 +408,8 @@ Status S3FileSystem::GetChildren(const string& dir, Aws::S3::Model::ListObjectsResult listObjectsResult; do { - auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); + auto listObjectsOutcome = + this->GetS3Client()->ListObjects(listObjectsRequest); if (!listObjectsOutcome.IsSuccess()) { 
string error = strings::StrCat( listObjectsOutcome.GetError().GetExceptionName().c_str(), ": ", @@ -481,7 +482,8 @@ Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) { .WithMaxKeys(1); listObjectsRequest.SetResponseStreamFactory( []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); }); - auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); + auto listObjectsOutcome = + this->GetS3Client()->ListObjects(listObjectsRequest); if (listObjectsOutcome.IsSuccess()) { if (listObjectsOutcome.GetResult().GetContents().size() > 0) { stats->length = 0; @@ -503,7 +505,7 @@ Status S3FileSystem::DeleteFile(const string& fname) { deleteObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str()); auto deleteObjectOutcome = - this->GetS3Client()->DeleteObject(deleteObjectRequest); + this->GetS3Client()->DeleteObject(deleteObjectRequest); if (!deleteObjectOutcome.IsSuccess()) { string error = strings::StrCat( deleteObjectOutcome.GetError().GetExceptionName().c_str(), ": ", @@ -550,7 +552,8 @@ Status S3FileSystem::DeleteDir(const string& dirname) { .WithMaxKeys(2); listObjectsRequest.SetResponseStreamFactory( []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); }); - auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); + auto listObjectsOutcome = + this->GetS3Client()->ListObjects(listObjectsRequest); if (listObjectsOutcome.IsSuccess()) { auto contents = listObjectsOutcome.GetResult().GetContents(); if (contents.size() > 1 || @@ -602,7 +605,8 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) { Aws::S3::Model::ListObjectsResult listObjectsResult; do { - auto listObjectsOutcome = this->GetS3Client()->ListObjects(listObjectsRequest); + auto listObjectsOutcome = + this->GetS3Client()->ListObjects(listObjectsRequest); if (!listObjectsOutcome.IsSuccess()) { string error = strings::StrCat( listObjectsOutcome.GetError().GetExceptionName().c_str(), ": ", 
@@ -615,14 +619,15 @@ Status S3FileSystem::RenameFile(const string& src, const string& target) { Aws::String src_key = object.GetKey(); Aws::String target_key = src_key; target_key.replace(0, src_object.length(), target_object.c_str()); - Aws::String source = Aws::String(src_bucket.c_str()) + "/" - + Aws::Utils::StringUtils::URLEncode(src_key.c_str()); + Aws::String source = Aws::String(src_bucket.c_str()) + "/" + + Aws::Utils::StringUtils::URLEncode(src_key.c_str()); copyObjectRequest.SetBucket(target_bucket.c_str()); copyObjectRequest.SetKey(target_key); copyObjectRequest.SetCopySource(source); - auto copyObjectOutcome = this->GetS3Client()->CopyObject(copyObjectRequest); + auto copyObjectOutcome = + this->GetS3Client()->CopyObject(copyObjectRequest); if (!copyObjectOutcome.IsSuccess()) { string error = strings::StrCat( copyObjectOutcome.GetError().GetExceptionName().c_str(), ": ", diff --git a/tensorflow/core/platform/s3/s3_file_system.h b/tensorflow/core/platform/s3/s3_file_system.h index 8177e48dba..31264be621 100644 --- a/tensorflow/core/platform/s3/s3_file_system.h +++ b/tensorflow/core/platform/s3/s3_file_system.h @@ -55,6 +55,7 @@ class S3FileSystem : public FileSystem { Status GetFileSize(const string& fname, uint64* size) override; Status RenameFile(const string& src, const string& target) override; + private: // Returns the member S3 client, initializing as-needed. // When the client tries to access the object in S3, e.g., diff --git a/tensorflow/core/util/event.proto b/tensorflow/core/util/event.proto index 5c3799c132..65d2c5a09c 100644 --- a/tensorflow/core/util/event.proto +++ b/tensorflow/core/util/event.proto @@ -80,3 +80,8 @@ message TaggedRunMetadata { // deserialization. 
bytes run_metadata = 2; } + +// For communicating live events back to a coordinator +message SessionStatus { + repeated Event event = 1; +} diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 4467373c00..db4c5c35e3 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -749,7 +749,6 @@ inline void GetMklInputList(OpKernelContext* ctext, StringPiece name, ctext->input_list(name, input_tensors); } - #ifdef INTEL_MKL_ML inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name, diff --git a/tensorflow/core/util/session_message.cc b/tensorflow/core/util/session_message.cc new file mode 100644 index 0000000000..28a6517a1a --- /dev/null +++ b/tensorflow/core/util/session_message.cc @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/util/session_message.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/util/event.pb.h" + +static const int kMaxLogEvents = 1000; + +namespace tensorflow { + +SessionLogger::SessionLogger() : status_(new SessionStatus) {} + +SessionLogger::~SessionLogger() {} + +string SessionLogger::DebugString() { return "SessionLogger"; } + +void SessionLogger::Log(StringPiece message) { + mutex_lock lock(mu_); + + Event* event = status_->add_event(); + event->set_wall_time(Env::Default()->NowMicros()); + event->set_step(0); + LogMessage* log = event->mutable_log_message(); + log->set_message(message.ToString()); + log->set_level(LogMessage::INFO); + + // Clip log events by 10% if we overflow + if (status_->event_size() > kMaxLogEvents) { + auto events = status_->mutable_event(); + events->DeleteSubrange(0, kMaxLogEvents / 10); + } +} + +SessionLogger* GetSessionLogger(ResourceMgr* rm) { + SessionLogger* logger; + + std::function<Status(SessionLogger**)> status_creator = + [](SessionLogger** result) { + *result = new SessionLogger(); + return Status::OK(); + }; + + if (!rm->LookupOrCreate<SessionLogger>("session", "status", &logger, + status_creator) + .ok()) { + return nullptr; + } + + return logger; +} + +void LogSessionMessage(ResourceMgr* rm, StringPiece message) { + return GetSessionLogger(rm)->Log(message); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/session_message.h b/tensorflow/core/util/session_message.h new file mode 100644 index 0000000000..c0f3d78b46 --- /dev/null +++ b/tensorflow/core/util/session_message.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_SESSION_MESSAGE_H_ +#define TENSORFLOW_CORE_UTIL_SESSION_MESSAGE_H_ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +class ResourceMgr; +class SessionStatus; + +class SessionLogger : public ResourceBase { + public: + SessionLogger(); + ~SessionLogger(); + + void Log(StringPiece message); + string DebugString() override; + + const SessionStatus& status() { return *status_; } + + private: + std::unique_ptr<SessionStatus> status_; + mutex mu_; +}; + +// Return a SessionLogger instance for the current session. If the logger +// will be used across multiple computations, you must explicitly acquire +// and release references using Ref()/Unref(). +// +// Returns nullptr if a logger cannot be created. +SessionLogger* GetSessionLogger(ResourceMgr* rm); + +// Attach `message` to the logger for the current session. 
+void LogSessionMessage(ResourceMgr* rm, StringPiece message); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_SESSION_MESSAGE_H diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java b/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java index bc0c738e53..068c7b0d94 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java @@ -82,7 +82,7 @@ public class LegacyCameraConnectionFragment extends Fragment { try { Camera.Parameters parameters = camera.getParameters(); List<String> focusModes = parameters.getSupportedFocusModes(); - if (focusModes != null + if (focusModes != null && focusModes.contains(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE)) { parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE); } diff --git a/tensorflow/examples/label_image/label_image.py b/tensorflow/examples/label_image/label_image.py index 1c1bd57d71..fe5e0fc684 100644 --- a/tensorflow/examples/label_image/label_image.py +++ b/tensorflow/examples/label_image/label_image.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function import argparse -import sys import numpy as np import tensorflow as tf diff --git a/tensorflow/examples/tutorials/mnist/input_data.py b/tensorflow/examples/tutorials/mnist/input_data.py index f1a7e1c4af..fa148ae3e6 100644 --- a/tensorflow/examples/tutorials/mnist/input_data.py +++ b/tensorflow/examples/tutorials/mnist/input_data.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +# pylint: disable=unused-import import gzip import os import tempfile @@ -27,3 +28,4 @@ from six.moves import urllib from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf from 
tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets +# pylint: enable=unused-import diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index cb47651d7b..a7290ff117 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -629,6 +629,77 @@ func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, v return scope.AddOperation(opspec) } +// MapPeekAttr is an optional argument to MapPeek. +type MapPeekAttr func(optionalAttr) + +// MapPeekCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapPeekCapacity(value int64) MapPeekAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// MapPeekMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func MapPeekMemoryLimit(value int64) MapPeekAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// MapPeekContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapPeekContainer(value string) MapPeekAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MapPeekSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapPeekSharedName(value string) MapPeekAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op peeks at the values at the specified key. If the +// +// underlying container does not contain this key +// this op will block until it does. 
+func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapPeekAttr) (values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MapPeek", + Input: []tf.Input{ + key, indices, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapPeek", err) + return + } + return values +} + // Returns (x - y)(x - y) element-wise. // // *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting @@ -4509,6 +4580,68 @@ func CriticalSectionOp(scope *Scope, optional ...CriticalSectionOpAttr) (resourc return op.Output(0) } +// FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient. +type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr) + +// FakeQuantWithMinMaxArgsGradientMin sets the optional min attribute to value. +// If not specified, defaults to -6 +func FakeQuantWithMinMaxArgsGradientMin(value float32) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["min"] = value + } +} + +// FakeQuantWithMinMaxArgsGradientMax sets the optional max attribute to value. +// If not specified, defaults to 6 +func FakeQuantWithMinMaxArgsGradientMax(value float32) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["max"] = value + } +} + +// FakeQuantWithMinMaxArgsGradientNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value. 
+// If not specified, defaults to false +func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Compute gradients for a FakeQuantWithMinMaxArgs operation. +// +// Arguments: +// gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. +// inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation. +// +// Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: +// `gradients * (inputs >= min && inputs <= max)`. +func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FakeQuantWithMinMaxArgsGradient", + Input: []tf.Input{ + gradients, inputs, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // AvgPool3DAttr is an optional argument to AvgPool3D. type AvgPool3DAttr func(optionalAttr) @@ -20680,68 +20813,6 @@ func TFRecordDataset(scope *Scope, filenames tf.Output, compression_type tf.Outp return op.Output(0) } -// FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient. -type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr) - -// FakeQuantWithMinMaxArgsGradientMin sets the optional min attribute to value. -// If not specified, defaults to -6 -func FakeQuantWithMinMaxArgsGradientMin(value float32) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["min"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientMax sets the optional max attribute to value. 
-// If not specified, defaults to 6 -func FakeQuantWithMinMaxArgsGradientMax(value float32) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["max"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Compute gradients for a FakeQuantWithMinMaxArgs operation. -// -// Arguments: -// gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. -// inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation. -// -// Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: -// `gradients * (inputs >= min && inputs <= max)`. -func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxArgsGradient", - Input: []tf.Input{ - gradients, inputs, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // BatchToSpace for 4-D tensors of type T. // // This is a legacy version of the more general BatchToSpaceND. @@ -22254,6 +22325,76 @@ func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { return scope.AddOperation(opspec) } +// Forwards the value of an available tensor from `inputs` to `output`. 
+// +// `Merge` waits for at least one of the tensors in `inputs` to become available. +// It is usually combined with `Switch` to implement branching. +// +// `Merge` forwards the first tensor to become available to `output`, and sets +// `value_index` to its index in `inputs`. +// +// Arguments: +// inputs: The input tensors, exactly one of which will become available. +// +// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. +func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Merge", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// QueueCloseV2Attr is an optional argument to QueueCloseV2. +type QueueCloseV2Attr func(optionalAttr) + +// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. +// +// value: If true, all pending enqueue requests that are +// blocked on the given queue will be canceled. +// If not specified, defaults to false +func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { + return func(m optionalAttr) { + m["cancel_pending_enqueues"] = value + } +} + +// Closes the given queue. +// +// This operation signals that no more elements will be enqueued in the +// given queue. Subsequent Enqueue(Many) operations will fail. +// Subsequent Dequeue(Many) operations will continue to succeed if +// sufficient elements remain in the queue. Subsequent Dequeue(Many) +// operations that would block will fail immediately. +// +// Arguments: +// handle: The handle to a queue. +// +// Returns the created operation. 
+func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QueueCloseV2", + Input: []tf.Input{ + handle, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + // Computes inverse hyperbolic tangent of x element-wise. func Atanh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { @@ -24203,147 +24344,6 @@ func MapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output return scope.AddOperation(opspec) } -// MapPeekAttr is an optional argument to MapPeek. -type MapPeekAttr func(optionalAttr) - -// MapPeekCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapPeekCapacity(value int64) MapPeekAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// MapPeekMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapPeekMemoryLimit(value int64) MapPeekAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// MapPeekContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapPeekContainer(value string) MapPeekAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MapPeekSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapPeekSharedName(value string) MapPeekAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op peeks at the values at the specified key. If the -// -// underlying container does not contain this key -// this op will block until it does. 
-func MapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapPeekAttr) (values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MapPeek", - Input: []tf.Input{ - key, indices, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapPeek", err) - return - } - return values -} - -// QueueCloseV2Attr is an optional argument to QueueCloseV2. -type QueueCloseV2Attr func(optionalAttr) - -// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. -// -// value: If true, all pending enqueue requests that are -// blocked on the given queue will be canceled. -// If not specified, defaults to false -func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { - return func(m optionalAttr) { - m["cancel_pending_enqueues"] = value - } -} - -// Closes the given queue. -// -// This operation signals that no more elements will be enqueued in the -// given queue. Subsequent Enqueue(Many) operations will fail. -// Subsequent Dequeue(Many) operations will continue to succeed if -// sufficient elements remain in the queue. Subsequent Dequeue(Many) -// operations that would block will fail immediately. -// -// Arguments: -// handle: The handle to a queue. -// -// Returns the created operation. 
-func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QueueCloseV2", - Input: []tf.Input{ - handle, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Forwards the value of an available tensor from `inputs` to `output`. -// -// `Merge` waits for at least one of the tensors in `inputs` to become available. -// It is usually combined with `Switch` to implement branching. -// -// `Merge` forwards the first tensor to become available to `output`, and sets -// `value_index` to its index in `inputs`. -// -// Arguments: -// inputs: The input tensors, exactly one of which will become available. -// -// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. -func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Merge", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - // MapUnstageAttr is an optional argument to MapUnstage. 
type MapUnstageAttr func(optionalAttr) diff --git a/tensorflow/python/client/session_benchmark.py b/tensorflow/python/client/session_benchmark.py index 06e9a09926..da74855193 100644 --- a/tensorflow/python/client/session_benchmark.py +++ b/tensorflow/python/client/session_benchmark.py @@ -22,7 +22,7 @@ import time import numpy as np -from six.moves import xrange +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index f12c005511..490572254b 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -31,7 +31,6 @@ from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.lib.core import error_codes_pb2 from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op @@ -48,7 +47,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops # Import resource_variable_ops for the variables-to-tensor implicit conversion. 
from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import from tensorflow.python.ops import state_ops diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py index c1e9869d97..01867fc69d 100644 --- a/tensorflow/python/debug/lib/debug_gradients_test.py +++ b/tensorflow/python/debug/lib/debug_gradients_test.py @@ -40,6 +40,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): def setUp(self): rewriter_config = rewriter_config_pb2.RewriterConfig( + disable_model_pruning=True, dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) config = config_pb2.ConfigProto(graph_options=graph_options) @@ -117,8 +118,8 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): def testCallingIdentifyGradientTwiceWithTheSameGradientsDebuggerErrors(self): grad_debugger = debug_gradients.GradientsDebugger() grad_debugger.identify_gradient(self.w) - with self.assertRaisesRegexp( - ValueError, "The graph already contains an op named .*"): + with self.assertRaisesRegexp(ValueError, + "The graph already contains an op named .*"): grad_debugger.identify_gradient(self.w) def testIdentifyGradientWorksOnMultipleLosses(self): @@ -144,10 +145,10 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): self.assertIsNot(dz1_dy, dz2_dy) self.sess.run(variables.global_variables_initializer()) - self.assertAllClose(5.0 ** 2, self.sess.run(z1)) - self.assertAllClose(5.0 ** 0.5, self.sess.run(z2)) + self.assertAllClose(5.0**2, self.sess.run(z1)) + self.assertAllClose(5.0**0.5, self.sess.run(z2)) self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy)) - self.assertAllClose(0.5 * (5.0 ** -0.5), self.sess.run(dz2_dy)) + self.assertAllClose(0.5 * (5.0**-0.5), self.sess.run(dz2_dy)) def testIdentifyGradientRaisesLookupErrorForUnknownXTensor(self): grad_debugger_1 = debug_gradients.GradientsDebugger() @@ -259,8 +260,8 @@ 
class IdentifyGradientTest(test_util.TensorFlowTestCase): self.sess.run(variables.global_variables_initializer()) self.assertAllClose(3.0, self.sess.run(u_grad)) self.assertAllClose(2.0, self.sess.run(v_grad)) - self.assertAllClose( - 3.0, self.sess.run(grad_debugger.gradient_tensor("u:0"))) + self.assertAllClose(3.0, self.sess.run( + grad_debugger.gradient_tensor("u:0"))) def testWatchGradientsWorksOnMultipleTensors(self): y = math_ops.add(self.w, -1.0, name="y") @@ -277,10 +278,10 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): self.assertIsInstance(grad_debugger.gradient_tensor("w:0"), ops.Tensor) self.sess.run(variables.global_variables_initializer()) - self.assertAllClose( - 1.0, self.sess.run(grad_debugger.gradient_tensor("w:0"))) - self.assertAllClose( - 3.0, self.sess.run(grad_debugger.gradient_tensor("u:0"))) + self.assertAllClose(1.0, self.sess.run( + grad_debugger.gradient_tensor("w:0"))) + self.assertAllClose(3.0, self.sess.run( + grad_debugger.gradient_tensor("u:0"))) def testWatchGradientsByXTensorsWorks(self): y = math_ops.add(self.w, -1.0, name="foo/y") @@ -290,8 +291,8 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): # But we can still get the gradient tensors by using # watch_gradients_by_x_tensors(). 
grad_debugger = debug_gradients.GradientsDebugger() - with grad_debugger.watch_gradients_by_tensors( - self.sess.graph, [self.w, self.u, y]): + with grad_debugger.watch_gradients_by_tensors(self.sess.graph, + [self.w, self.u, y]): gradient_descent.GradientDescentOptimizer(0.1).minimize(z) self.assertEqual(3, len(grad_debugger.gradient_tensors())) @@ -324,18 +325,18 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): self.assertIsNot(dz1_dy, dz2_dy) self.sess.run(variables.global_variables_initializer()) - self.assertAllClose(5.0 ** 2, self.sess.run(z1)) - self.assertAllClose(5.0 ** 0.5, self.sess.run(z2)) + self.assertAllClose(5.0**2, self.sess.run(z1)) + self.assertAllClose(5.0**0.5, self.sess.run(z2)) self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy)) - self.assertAllClose(0.5 * (5.0 ** -0.5), self.sess.run(dz2_dy)) + self.assertAllClose(0.5 * (5.0**-0.5), self.sess.run(dz2_dy)) def testGradientsValuesFromDumpWorks(self): y = math_ops.add(self.w, -1.0, name="y") z = math_ops.square(y, name="z") grad_debugger = debug_gradients.GradientsDebugger() - with grad_debugger.watch_gradients_by_tensors( - self.sess.graph, [self.w, self.u, y]): + with grad_debugger.watch_gradients_by_tensors(self.sess.graph, + [self.w, self.u, y]): train_op = gradient_descent.GradientDescentOptimizer(0.1).minimize(z) self.sess.run(variables.global_variables_initializer()) @@ -343,10 +344,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase): run_options = config_pb2.RunOptions(output_partition_graphs=True) dump_dir = tempfile.mkdtemp() debug_url = "file://" + dump_dir - debug_utils.watch_graph( - run_options, - self.sess.graph, - debug_urls=debug_url) + debug_utils.watch_graph(run_options, self.sess.graph, debug_urls=debug_url) run_metadata = config_pb2.RunMetadata() self.assertAllClose(2.0, self.sess.run(self.u)) self.sess.run(train_op, options=run_options, run_metadata=run_metadata) diff --git a/tensorflow/python/estimator/run_config.py 
b/tensorflow/python/estimator/run_config.py index e446b3e03a..0c636a8da1 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -27,7 +27,6 @@ import six from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib -from tensorflow.python.util import compat from tensorflow.python.util import compat_internal diff --git a/tensorflow/python/estimator/warm_starting_util.py b/tensorflow/python/estimator/warm_starting_util.py index 48110ef57f..57db968d56 100644 --- a/tensorflow/python/estimator/warm_starting_util.py +++ b/tensorflow/python/estimator/warm_starting_util.py @@ -117,21 +117,13 @@ class WarmStartSettings( ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000") ``` - Warm-start only the embeddings (input layer) and their accumulator variables: + Warm-start only the embeddings (input layer): ``` ws = WarmStartSettings(ckpt_to_initialize_from="/tmp", vars_to_warm_start=".*input_layer.*") ``` - Warm-start everything except the optimizer accumulator variables - (DNN defaults to Adagrad): - - ``` - ws = WarmStartSettings(ckpt_to_initialize_from="/tmp", - vars_to_warm_start="^(?!.*(Adagrad))") - ``` - Warm-start all weights but the embedding parameters corresponding to `sc_vocab_file` have a different vocab from the one used in the current model: @@ -423,6 +415,8 @@ def _warm_start(warm_start_settings): # Both warm_start_settings.vars_to_warm_start = '.*' and # warm_start_settings.vars_to_warm_start = None will match everything here. for v in ops.get_collection( + # TODO(eddz): Allow for different collections here (to support + # warm-starting accumulators). 
ops.GraphKeys.TRAINABLE_VARIABLES, scope=warm_start_settings.vars_to_warm_start): if not isinstance(v, list): diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py index c997ead829..1f2aa264c1 100644 --- a/tensorflow/python/framework/load_library.py +++ b/tensorflow/python/framework/load_library.py @@ -21,10 +21,10 @@ from __future__ import print_function import hashlib import imp import sys -import threading +import threading # pylint: disable=unused-import from tensorflow.core.framework import op_def_pb2 -from tensorflow.core.lib.core import error_codes_pb2 +from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import from tensorflow.python import pywrap_tensorflow as py_tf from tensorflow.python.framework import errors_impl from tensorflow.python.util import compat diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 15e8f5a38d..bfdd98819e 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -49,7 +49,7 @@ from tensorflow.python.client import device_lib from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.eager import tape +from tensorflow.python.eager import tape # pylint: disable=unused-import from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -1146,13 +1146,19 @@ class TensorFlowTestCase(googletest.TestCase): del path[-1] # a and b are ndarray like objects else: - self._assertArrayLikeAllClose( - a, - b, - rtol=rtol, - atol=atol, - msg="Mismatched value: a%s is different from b%s." % (path_str, - path_str)) + try: + self._assertArrayLikeAllClose( + a, + b, + rtol=rtol, + atol=atol, + msg="Mismatched value: a%s is different from b%s." 
% (path_str, + path_str)) + except TypeError as e: + msg = "Error: a%s has %s, but b%s has %s" % ( + path_str, type(a), path_str, type(b)) + e.args = ((e.args[0] + ' : ' + msg,) + e.args[1:]) + raise def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6): """Asserts that two structures of numpy arrays, have near values. diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i index 0c8d04ff29..8079cb307b 100644 --- a/tensorflow/python/grappler/cluster.i +++ b/tensorflow/python/grappler/cluster.i @@ -140,6 +140,7 @@ static GCluster TF_NewCluster(bool allow_soft_placement, timeout_s, num_cpu_cores, num_gpus); cluster_->DisableDetailedStats(disable_detailed_stats); cluster_->AllowSoftPlacement(allow_soft_placement); + cluster_->SetNumWarmupSteps(10); tensorflow::Status status = cluster_->Provision(); tensorflow::Set_TF_Status_from_Status(out_status, status); return GCluster(cluster_); diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index ee7a5621e0..586130f806 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import time +import unittest import numpy as np @@ -1182,6 +1183,29 @@ class UnravelIndexTest(test_util.TensorFlowTestCase): self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]]) +class UnravelIndexTest(test_util.TensorFlowTestCase): + + # TODO(b/73086570): Reenable test. 
+ @unittest.skip("Test does not pass internally.") + def testUnravelIndex(self): + with self.test_session(): + for dtype in [dtypes.int32, dtypes.int64]: + indices_1 = constant_op.constant(1621, dtype=dtype) + dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype) + out_1 = array_ops.unravel_index(indices_1, dims_1) + self.assertAllEqual(out_1.eval(), [3, 1, 4, 1]) + + indices_2 = constant_op.constant([1621], dtype=dtype) + dims_2 = constant_op.constant([6, 7, 8, 9], dtype=dtype) + out_2 = array_ops.unravel_index(indices_2, dims_2) + self.assertAllEqual(out_2.eval(), [[3], [1], [4], [1]]) + + indices_3 = constant_op.constant([22, 41, 37], dtype=dtype) + dims_3 = constant_op.constant([7, 6], dtype=dtype) + out_3 = array_ops.unravel_index(indices_3, dims_3) + self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]]) + + class GuaranteeConstOpTest(test_util.TensorFlowTestCase): def testSimple(self): diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py index a5fd3bc334..127bc6bb20 100644 --- a/tensorflow/python/kernel_tests/concat_op_test.py +++ b/tensorflow/python/kernel_tests/concat_op_test.py @@ -495,9 +495,9 @@ class ConcatOpTest(test.TestCase): p = [] shape = np.array([7, 13]) if test.is_gpu_available(): - num_tensors = 10000 + num_tensors = 5000 else: - num_tensors = 1000 + num_tensors = 500 for i in np.arange(num_tensors): input_shape = shape placeholder = array_ops.placeholder(dtypes.float32, shape=input_shape) diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index 576bb68ba4..16e56349c4 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -465,9 +465,8 @@ class ZerosLikeTest(test.TestCase): def testZerosLikeGPU(self): for dtype in [ dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64, - dtypes_lib.int32, dtypes_lib.int64, - dtypes_lib.complex64, 
dtypes_lib.complex128, - dtypes_lib.bool + dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.complex64, + dtypes_lib.complex128, dtypes_lib.bool ]: self._compareZeros(dtype, fully_defined_shape=False, use_gpu=True) self._compareZeros(dtype, fully_defined_shape=True, use_gpu=True) diff --git a/tensorflow/python/kernel_tests/conv2d_transpose_test.py b/tensorflow/python/kernel_tests/conv2d_transpose_test.py index 7d0bc54b69..1a65c3f429 100644 --- a/tensorflow/python/kernel_tests/conv2d_transpose_test.py +++ b/tensorflow/python/kernel_tests/conv2d_transpose_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.python.client import device_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index c5446326ba..edfb20d6a2 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -24,7 +24,7 @@ import time import numpy as np -from six.moves import xrange +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import layers from tensorflow.python.client import session as session_lib from tensorflow.python.framework import constant_op @@ -520,7 +520,7 @@ class Conv2DTest(test.TestCase): dilations=[2, 2], padding="VALID") - # TODO this currently fails. + # TODO(yzhwang): this currently fails. 
# self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1], # filter_in_sizes=[2, 2, 1, 1], # strides=[4, 4], padding="SAME", diff --git a/tensorflow/python/kernel_tests/decode_bmp_op_test.py b/tensorflow/python/kernel_tests/decode_bmp_op_test.py index c67c26b7be..35f8f76991 100644 --- a/tensorflow/python/kernel_tests/decode_bmp_op_test.py +++ b/tensorflow/python/kernel_tests/decode_bmp_op_test.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops from tensorflow.python.ops import image_ops from tensorflow.python.platform import test diff --git a/tensorflow/python/kernel_tests/decode_raw_op_test.py b/tensorflow/python/kernel_tests/decode_raw_op_test.py index 0c7025f54e..122a9ed469 100644 --- a/tensorflow/python/kernel_tests/decode_raw_op_test.py +++ b/tensorflow/python/kernel_tests/decode_raw_op_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import numpy as np -import sys from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index 748135440e..ce73e7ad3e 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import random -import re import time import numpy as np diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index f1fbe1a745..197dbf44af 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -953,8 +953,8 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): # Compute the expected loss 'manually'. 
total = np.zeros((batch_size,)) for b in range(batch_size): - for i in range(dims-1): - for j in range(i+1, dims): + for i in range(dims - 1): + for j in range(i + 1, dims): x = self._predictions[b, i].item() - self._predictions[b, j].item() y = self._labels[b, i].item() - self._labels[b, j].item() diff = (x - y) @@ -1059,8 +1059,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): [[4, 8, 12], [1, 2, 3], [4, 5, 6]], [[8, 1, 3], [7, 8, 9], [10, 11, 12]], ]) - self._test_valid_weights( - labels, predictions, expected_loss=137.5) + self._test_valid_weights(labels, predictions, expected_loss=137.5) def test3dWeightedScalar(self): labels = np.array([ @@ -1073,8 +1072,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): ]) weight = 3.0 self._test_valid_weights( - labels, predictions, expected_loss=weight * 137.5, - weights=weight) + labels, predictions, expected_loss=weight * 137.5, weights=weight) def _test_invalid_weights( self, labels, predictions, weights=1.0): @@ -1124,7 +1122,9 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): ]) self._test_valid_weights( # TODO(ptucker): This doesn't look right. 
- labels, predictions, expected_loss=9 * 137.5, + labels, + predictions, + expected_loss=9 * 137.5, weights=np.ones((2, 3, 3))) def testLossWithAllZeroBatchSpecificWeights(self): diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py index 3044b21aa4..b8200ac0cb 100644 --- a/tensorflow/python/kernel_tests/manip_ops_test.py +++ b/tensorflow/python/kernel_tests/manip_ops_test.py @@ -17,25 +17,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util -from tensorflow.python.ops import manip_ops from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import manip_ops from tensorflow.python.platform import test as test_lib -import numpy as np - # pylint: disable=g-import-not-at-top try: from distutils.version import StrictVersion as Version # numpy.roll for multiple shifts was introduced in numpy version 1.12.0 - NP_ROLL_CAN_MULTISHIFT = Version(np.version.version) >= Version('1.12.0') + NP_ROLL_CAN_MULTISHIFT = Version(np.version.version) >= Version("1.12.0") except ImportError: NP_ROLL_CAN_MULTISHIFT = False # pylint: enable=g-import-not-at-top + class RollTest(test_util.TensorFlowTestCase): + def _testRoll(self, np_input, shift, axis): expected_roll = np.roll(np_input, shift, axis) with self.test_session(): @@ -62,10 +64,12 @@ class RollTest(test_util.TensorFlowTestCase): for t in [np.int32, np.int64]: self._testAll(np.random.randint(-100, 100, (5)).astype(t), 3, 0) if NP_ROLL_CAN_MULTISHIFT: - self._testAll(np.random.randint(-100, 100, (4, 4, 3)).astype(t), - [1, -2, 3], [0, 1, 2]) - self._testAll(np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t), - [0, 1, -2], [1, 2, 3]) + self._testAll( + np.random.randint(-100, 100, (4, 4, 3)).astype(t), [1, -2, 3], + 
[0, 1, 2]) + self._testAll( + np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t), [0, 1, -2], + [1, 2, 3]) def testFloatTypes(self): for t in [np.float32, np.float64]: @@ -84,7 +88,6 @@ class RollTest(test_util.TensorFlowTestCase): x = np.random.rand(3, 2, 1, 1).astype(t) self._testAll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2]) - def testRollInputMustVectorHigherRaises(self): tensor = 7 shift = 1 @@ -95,8 +98,7 @@ class RollTest(test_util.TensorFlowTestCase): manip_ops.roll(tensor, shift, axis).eval() def testRollAxisMustBeScalarOrVectorRaises(self): - tensor = [[1, 2], - [3, 4]] + tensor = [[1, 2], [3, 4]] shift = 1 axis = [[0, 1]] with self.test_session(): @@ -105,8 +107,7 @@ class RollTest(test_util.TensorFlowTestCase): manip_ops.roll(tensor, shift, axis).eval() def testRollShiftMustBeScalarOrVectorRaises(self): - tensor = [[1, 2], - [3, 4]] + tensor = [[1, 2], [3, 4]] shift = [[0, 1]] axis = 1 with self.test_session(): @@ -115,8 +116,7 @@ class RollTest(test_util.TensorFlowTestCase): manip_ops.roll(tensor, shift, axis).eval() def testRollShiftAndAxisMustBeSameSizeRaises(self): - tensor = [[1, 2], - [3, 4]] + tensor = [[1, 2], [3, 4]] shift = [1] axis = [0, 1] with self.test_session(): @@ -133,5 +133,6 @@ class RollTest(test_util.TensorFlowTestCase): "is out of range"): manip_ops.roll(tensor, shift, axis).eval() + if __name__ == "__main__": test_lib.main() diff --git a/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py b/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py index c4e16ff628..b7a79f239c 100644 --- a/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py +++ b/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import random -import re import time import numpy as np diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index a86b65affe..daa42938e6 100644 --- 
a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -23,7 +23,7 @@ import timeit import numpy as np -from six.moves import xrange +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import rnn as contrib_rnn from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py index bb3f6970e4..ac08f2aec0 100644 --- a/tensorflow/python/kernel_tests/softmax_op_test.py +++ b/tensorflow/python/kernel_tests/softmax_op_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys - import numpy as np from tensorflow.python.framework import constant_op diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py index 20445c78a2..220ef1754d 100644 --- a/tensorflow/python/ops/candidate_sampling_ops.py +++ b/tensorflow/python/ops/candidate_sampling_ops.py @@ -20,9 +20,9 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops +from tensorflow.python.ops import array_ops # pylint: disable=unused-import from tensorflow.python.ops import gen_candidate_sampling_ops -from tensorflow.python.ops import math_ops +from tensorflow.python.ops import math_ops # pylint: disable=unused-import from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 9ae9a71e4b..3a6fdaafb9 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -55,7 +55,6 @@ import collections import functools import six -from six.moves import xrange # pylint: disable=redefined-builtin from 
tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import control_flow_pb2 diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 230b6c5946..9f06c0ee1f 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -35,7 +35,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_grad # pylint: disable=unused-import from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops +from tensorflow.python.ops import check_ops # pylint: disable=unused-import from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 22636fdbb3..14a38f25d1 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -1671,8 +1671,8 @@ def non_max_suppression(boxes, # pylint: enable=protected-access -_rgb_to_yiq_kernel = [[0.299, 0.59590059, 0.2115], - [0.587, -0.27455667, -0.52273617], +_rgb_to_yiq_kernel = [[0.299, 0.59590059, + 0.2115], [0.587, -0.27455667, -0.52273617], [0.114, -0.32134392, 0.31119955]] @@ -1694,11 +1694,10 @@ def rgb_to_yiq(images): kernel = ops.convert_to_tensor( _rgb_to_yiq_kernel, dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims - return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) + return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) -_yiq_to_rgb_kernel = [[1, 1, 1], - [0.95598634, -0.27201283, -1.10674021], +_yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021], [0.6208248, -0.64720424, 1.70423049]] @@ -1721,11 +1720,11 @@ def yiq_to_rgb(images): kernel = ops.convert_to_tensor( _yiq_to_rgb_kernel, 
dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims - return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) + return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) -_rgb_to_yuv_kernel = [[0.299, -0.14714119, 0.61497538], - [0.587, -0.28886916, -0.51496512], +_rgb_to_yuv_kernel = [[0.299, -0.14714119, + 0.61497538], [0.587, -0.28886916, -0.51496512], [0.114, 0.43601035, -0.10001026]] @@ -1747,11 +1746,10 @@ def rgb_to_yuv(images): kernel = ops.convert_to_tensor( _rgb_to_yuv_kernel, dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims - return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) + return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) -_yuv_to_rgb_kernel = [[1, 1, 1], - [0, -0.394642334, 2.03206185], +_yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185], [1.13988303, -0.58062185, 0]] @@ -1774,5 +1772,4 @@ def yuv_to_rgb(images): kernel = ops.convert_to_tensor( _yuv_to_rgb_kernel, dtype=images.dtype, name='kernel') ndims = images.get_shape().ndims - return math_ops.tensordot(images, kernel, axes=[[ndims-1], [0]]) - + return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 0dc1c56e7d..1f1bcfc8f6 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -1905,7 +1905,8 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase): bounding_box = constant_op.constant( [0.0, 0.0, 1.0, 1.0], shape=[4], - dtype=dtypes.float32,) + dtype=dtypes.float32, + ) begin, end, bbox_for_drawing = image_ops.sample_distorted_bounding_box( image_size=image_size, bounding_boxes=bounding_box, @@ -3153,43 +3154,37 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase): def testInvalidShape(self): # The boxes should be 2D of shape [num_boxes, 4]. 
- with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 2 but is rank 1'): + with self.assertRaisesRegexp(ValueError, + "Shape must be rank 2 but is rank 1"): boxes = constant_op.constant([0.0, 0.0, 1.0, 1.0]) scores = constant_op.constant([0.9]) - selected_indices = image_ops.non_max_suppression( - boxes, scores, 3, 0.5) + image_ops.non_max_suppression(boxes, scores, 3, 0.5) - with self.assertRaisesRegexp( - ValueError, 'Dimension must be 4 but is 3'): + with self.assertRaisesRegexp(ValueError, "Dimension must be 4 but is 3"): boxes = constant_op.constant([[0.0, 0.0, 1.0]]) scores = constant_op.constant([0.9]) - selected_indices = image_ops.non_max_suppression( - boxes, scores, 3, 0.5) + image_ops.non_max_suppression(boxes, scores, 3, 0.5) # The scores should be 1D of shape [num_boxes]. - with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 1 but is rank 2'): + with self.assertRaisesRegexp(ValueError, + "Shape must be rank 1 but is rank 2"): boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) scores = constant_op.constant([[0.9]]) - selected_indices = image_ops.non_max_suppression( - boxes, scores, 3, 0.5) + image_ops.non_max_suppression(boxes, scores, 3, 0.5) # The max_output_size should be a scaler (0-D). - with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 0 but is rank 1'): + with self.assertRaisesRegexp(ValueError, + "Shape must be rank 0 but is rank 1"): boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) scores = constant_op.constant([0.9]) - selected_indices = image_ops.non_max_suppression( - boxes, scores, [3], 0.5) + image_ops.non_max_suppression(boxes, scores, [3], 0.5) # The iou_threshold should be a scaler (0-D). 
- with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 0 but is rank 2'): + with self.assertRaisesRegexp(ValueError, + "Shape must be rank 0 but is rank 2"): boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) scores = constant_op.constant([0.9]) - selected_indices = image_ops.non_max_suppression( - boxes, scores, 3, [[0.5]]) + image_ops.non_max_suppression(boxes, scores, 3, [[0.5]]) if __name__ == "__main__": diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 0907ea69eb..3368285bc6 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -547,13 +547,13 @@ def mean_pairwise_squared_error( num_present_per_batch = _num_present(diffs, weights, per_batch=True) term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, - num_present_per_batch-1) + num_present_per_batch - 1) sum_diff = math_ops.reduce_sum( diffs, reduction_indices=reduction_indices, keep_dims=True) term2 = 2.0 * _safe_div( math_ops.square(sum_diff), - math_ops.multiply(num_present_per_batch, num_present_per_batch-1)) + math_ops.multiply(num_present_per_batch, num_present_per_batch - 1)) weighted_losses = math_ops.multiply(term1 - term2, weights) loss = math_ops.reduce_sum(weighted_losses) diff --git a/tensorflow/python/ops/manip_grad.py b/tensorflow/python/ops/manip_grad.py index 573e8c0a0d..bb2069359d 100644 --- a/tensorflow/python/ops/manip_grad.py +++ b/tensorflow/python/ops/manip_grad.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== - """Gradients for operators defined in manip_ops.py.""" from __future__ import absolute_import diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py index c5f39784f4..91e15b47b9 100644 --- a/tensorflow/python/ops/manip_ops.py +++ b/tensorflow/python/ops/manip_ops.py @@ -24,10 +24,12 @@ from __future__ import print_function from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops from tensorflow.python.util.all_util import remove_undocumented + # pylint: disable=protected-access -def roll(input, shift, axis): +def roll(input, shift, axis): # pylint: disable=redefined-builtin return _gen_manip_ops.roll(input, shift, axis) + roll.__doc__ = _gen_manip_ops.roll.__doc__ # pylint: enable=protected-access diff --git a/tensorflow/python/ops/nn_grad_test.py b/tensorflow/python/ops/nn_grad_test.py index aa7539ae9f..49d54beb20 100644 --- a/tensorflow/python/ops/nn_grad_test.py +++ b/tensorflow/python/ops/nn_grad_test.py @@ -24,7 +24,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import nn_grad +from tensorflow.python.ops import nn_grad # pylint: disable=unused-import from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 55fcd176d6..5fa5708114 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -27,7 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import candidate_sampling_ops from tensorflow.python.ops import embedding_ops -from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_array_ops # pylint: disable=unused-import 
from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 737b923415..f83fdfb17b 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -25,6 +25,7 @@ import sys as _sys # Imports the following modules so that @RegisterGradient get executed. from tensorflow.python.ops import array_grad from tensorflow.python.ops import data_flow_grad +from tensorflow.python.ops import manip_grad from tensorflow.python.ops import math_grad from tensorflow.python.ops import manip_grad from tensorflow.python.ops import sparse_grad @@ -43,11 +44,13 @@ from tensorflow.python.ops.special_math_ops import * # TODO(vrv): Switch to import * once we're okay with exposing the module. from tensorflow.python.ops.confusion_matrix import confusion_matrix from tensorflow.python.ops.control_flow_ops import Assert +from tensorflow.python.ops.control_flow_ops import case +from tensorflow.python.ops.control_flow_ops import cond from tensorflow.python.ops.control_flow_ops import group from tensorflow.python.ops.control_flow_ops import no_op +# pylint: disable=redefined-builtin from tensorflow.python.ops.control_flow_ops import tuple -from tensorflow.python.ops.control_flow_ops import cond -from tensorflow.python.ops.control_flow_ops import case +# pylint: enable=redefined-builtin from tensorflow.python.ops.control_flow_ops import while_loop from tensorflow.python.ops.data_flow_ops import * from tensorflow.python.ops.functional_ops import * @@ -267,35 +270,36 @@ _allowed_symbols = (_allowed_symbols_array_ops + _allowed_symbols_misc + _allowed_symbols_partitioned_variables) -remove_undocumented(__name__, _allowed_symbols, - [_sys.modules[__name__], - _array_ops, - _check_ops, - _clip_ops, - _confusion_matrix, - _control_flow_ops, - _constant_op, - _data_flow_ops, - _functional_ops, - _gradients, - 
_histogram_ops, - _init_ops, - _io_ops, - _linalg_ops, - _logging_ops, - _manip_ops, - _math_ops, - _numerics, - _parsing_ops, - _partitioned_variables, - _random_ops, - _script_ops, - _session_ops, - _sparse_ops, - _special_math_ops, - _state_ops, - _string_ops, - _template, - _tensor_array_ops, - _variable_scope, - _variables,]) +remove_undocumented(__name__, _allowed_symbols, [ + _sys.modules[__name__], + _array_ops, + _check_ops, + _clip_ops, + _confusion_matrix, + _control_flow_ops, + _constant_op, + _data_flow_ops, + _functional_ops, + _gradients, + _histogram_ops, + _init_ops, + _io_ops, + _linalg_ops, + _logging_ops, + _manip_ops, + _math_ops, + _numerics, + _parsing_ops, + _partitioned_variables, + _random_ops, + _script_ops, + _session_ops, + _sparse_ops, + _special_math_ops, + _state_ops, + _string_ops, + _template, + _tensor_array_ops, + _variable_scope, + _variables, +]) diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index fd22164897..bebf1d5e0d 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -235,8 +235,9 @@ def load(sess, tags, export_dir, **saver_kwargs): asset_tensors_dictionary = _get_asset_tensors(export_dir, meta_graph_def_to_load) - main_op_tensor = (_get_main_op_tensor(meta_graph_def_to_load) or - (_get_legacy_init_op_tensor(meta_graph_def_to_load))) + main_op_tensor = ( + _get_main_op_tensor(meta_graph_def_to_load) or + (_get_legacy_init_op_tensor(meta_graph_def_to_load))) if main_op_tensor is not None: sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index fd78f44c99..affa97062a 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -101,8 +101,8 @@ def freeze_graph_with_def_protos(input_graph_def, _ = importer.import_graph_def(input_graph_def, name="") with 
session.Session() as sess: if input_saver_def: - saver = saver_lib.Saver(saver_def=input_saver_def, - write_version=checkpoint_version) + saver = saver_lib.Saver( + saver_def=input_saver_def, write_version=checkpoint_version) saver.restore(sess, input_checkpoint) elif input_meta_graph_def: restorer = saver_lib.import_meta_graph( @@ -126,8 +126,8 @@ def freeze_graph_with_def_protos(input_graph_def, # 'global_step' or a similar housekeeping element) so skip it. continue var_list[key] = tensor - saver = saver_lib.Saver(var_list=var_list, - write_version=checkpoint_version) + saver = saver_lib.Saver( + var_list=var_list, write_version=checkpoint_version) saver.restore(sess, input_checkpoint) if initializer_nodes: sess.run(initializer_nodes.split(",")) @@ -237,11 +237,21 @@ def freeze_graph(input_graph, if input_saver: input_saver_def = _parse_input_saver_proto(input_saver, input_binary) freeze_graph_with_def_protos( - input_graph_def, input_saver_def, input_checkpoint, output_node_names, - restore_op_name, filename_tensor_name, output_graph, clear_devices, - initializer_nodes, variable_names_whitelist, variable_names_blacklist, - input_meta_graph_def, input_saved_model_dir, - saved_model_tags.split(","), checkpoint_version=checkpoint_version) + input_graph_def, + input_saver_def, + input_checkpoint, + output_node_names, + restore_op_name, + filename_tensor_name, + output_graph, + clear_devices, + initializer_nodes, + variable_names_whitelist, + variable_names_blacklist, + input_meta_graph_def, + input_saved_model_dir, + saved_model_tags.split(","), + checkpoint_version=checkpoint_version) def main(unused_args): diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py index 342732465d..91f0061ebc 100644 --- a/tensorflow/python/tools/freeze_graph_test.py +++ b/tensorflow/python/tools/freeze_graph_test.py @@ -84,9 +84,18 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): input_meta_graph = checkpoint_meta_graph_file 
freeze_graph.freeze_graph( - input_graph_path, input_saver_def_path, input_binary, checkpoint_path, - output_node_names, restore_op_name, filename_tensor_name, - output_graph_path, clear_devices, "", "", input_meta_graph, + input_graph_path, + input_saver_def_path, + input_binary, + checkpoint_path, + output_node_names, + restore_op_name, + filename_tensor_name, + output_graph_path, + clear_devices, + "", + "", + input_meta_graph, checkpoint_version=saver_write_version) # Now we make sure the variable is now a constant, and that the graph still diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py index 2ef612473b..084a4500f8 100644 --- a/tensorflow/python/tools/optimize_for_inference_test.py +++ b/tensorflow/python/tools/optimize_for_inference_test.py @@ -184,8 +184,11 @@ class OptimizeForInferenceTest(test.TestCase): weights_op = constant_op.constant( np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32) conv_op = nn_ops.conv2d( - input_op, weights_op, [1, 1, 1, 1], padding="SAME", - data_format=data_format, name="conv_op") + input_op, + weights_op, [1, 1, 1, 1], + padding="SAME", + data_format=data_format, + name="conv_op") mean_op = constant_op.constant( np.array([10, 20]), shape=[2], dtype=dtypes.float32) variance_op = constant_op.constant( diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 5b0a584c10..33f6debbcb 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -39,7 +39,7 @@ from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.debug.wrappers import local_cli_wrapper from tensorflow.python.framework import ops as ops_lib -from tensorflow.python.platform import app +from tensorflow.python.platform import app # pylint: disable=unused-import from tensorflow.python.saved_model import loader from 
tensorflow.python.tools import saved_model_utils diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 764f840012..0c1c8e664b 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -1750,7 +1750,8 @@ class Saver(object): if save_path is None: raise ValueError("Can't load save_path when it is None.") if (os.path.isfile(save_path) and - self._write_version != saver_pb2.SaverDef.V1): + self._write_version not in ( + saver_pb2.SaverDef.V1, saver_pb2.SaverDef.LEGACY)): raise ValueError("The specified path: %s is a file." " Please specify only the path prefix" " to the checkpoint files." % save_path) diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py index 731fe34273..18a5b89d30 100644 --- a/tensorflow/python/training/slot_creator.py +++ b/tensorflow/python/training/slot_creator.py @@ -111,7 +111,8 @@ def create_slot(primary, val, name, colocate_with_primary=True): # and the same name has been previously used, the scope name will add '_N' # as suffix for unique identifications. 
validate_shape = val.get_shape().is_fully_defined() - with variable_scope.variable_scope(None, primary.op.name + "/" + name): + prefix = primary.op.name if context.in_graph_mode() else primary._shared_name # pylint: disable=protected-access + with variable_scope.variable_scope(None, prefix + "/" + name): if colocate_with_primary: with ops.colocate_with(primary): return _create_slot_var(primary, val, "", validate_shape, None, None) diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py index e98c32b614..d7133cfb50 100644 --- a/tensorflow/python/training/training_ops.py +++ b/tensorflow/python/training/training_ops.py @@ -19,7 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.training import gen_training_ops +from tensorflow.python.training import gen_training_ops # pylint: disable=unused-import # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.training.gen_training_ops import * diff --git a/tensorflow/python/util/compat_internal.py b/tensorflow/python/util/compat_internal.py index 9e60e689d2..d8b9319f66 100644 --- a/tensorflow/python/util/compat_internal.py +++ b/tensorflow/python/util/compat_internal.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Functions for Python 2 vs. 3 compatibility that are private to TensorFlow.""" from __future__ import absolute_import @@ -21,9 +20,9 @@ from __future__ import print_function from tensorflow.python.util.compat import as_str_any - def path_to_str(path): - """Returns the file system path representation of a `PathLike` object, else as it is. + """Returns the file system path representation of a `PathLike` object, + else as it is. Args: path: An object that can be converted to path representation. 
diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc index d71938634d..0c642912b1 100644 --- a/tensorflow/stream_executor/dso_loader.cc +++ b/tensorflow/stream_executor/dso_loader.cc @@ -97,11 +97,12 @@ string GetCudnnVersion() { return TF_CUDNN_VERSION; } /* static */ port::Status DsoLoader::GetLibcuptiDsoHandle(void** dso_handle) { #if defined(ANDROID_TEGRA) - // On Android devices the CUDA version number is not added to the library name. - return GetDsoHandle(FindDsoPath(port::Env::Default()->FormatLibraryFileName( - "cupti", ""), - GetCudaCuptiLibraryPath()), - dso_handle); + // On Android devices the CUDA version number is not added to the library + // name. + return GetDsoHandle( + FindDsoPath(port::Env::Default()->FormatLibraryFileName("cupti", ""), + GetCudaCuptiLibraryPath()), + dso_handle); #else return GetDsoHandle(FindDsoPath(port::Env::Default()->FormatLibraryFileName( "cupti", GetCudaVersion()), diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index fd5d005844..03a675a644 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -185,7 +185,8 @@ do_pylint() { # C0330 bad-continuation # C0301 line-too-long # C0326 bad-whitespace - grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301|\[C0326)' ${OUTPUT_FILE} > ${ERRORS_FILE} + # W0611 unused-import + grep -E '(\[E|\[W0311|\[W0312|\[C0330|\[C0301|\[C0326|\[W0611)' ${OUTPUT_FILE} > ${ERRORS_FILE} N_ERRORS=0 while read -r LINE; do diff --git a/tensorflow/tools/docs/generate_1_0.py b/tensorflow/tools/docs/generate_1_0.py index cdc03fdcac..f4384e0ced 100644 --- a/tensorflow/tools/docs/generate_1_0.py +++ b/tensorflow/tools/docs/generate_1_0.py @@ -53,7 +53,6 @@ if __name__ == '__main__': 'factorization', 'grid_rnn', 'labeled_tensor', - 'ndlstm', 'quantization', 'session_bundle', 'slim', diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 
003f972070..34dd419f15 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -215,7 +215,6 @@ def _get_default_do_not_descend_map(): # Block contrib.keras to de-clutter the docs 'keras', 'labeled_tensor', - 'ndlstm', 'quantization', 'session_bundle', 'slim', diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index a9c4a8de42..3189bd09fc 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -70,7 +70,6 @@ py_binary( "//tensorflow/python/eager:eager_pip", "//tensorflow/contrib/summary:summary_test_util", # These targets don't build on Windows yet. Exclude them for now. - # "//tensorflow/contrib/ndlstm", # "//tensorflow/contrib/slim", # "//tensorflow/contrib/slim/python/slim/nets:nets_pip", # "//tensorflow/contrib/specs", @@ -159,7 +158,6 @@ sh_binary( "//tensorflow/contrib/lite/toco:toco", "//tensorflow/contrib/lite/toco/python:toco_wrapper", "//tensorflow/contrib/lite/toco/python:toco_from_protos", - "//tensorflow/contrib/ndlstm:ndlstm", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/predictor:predictor_pip", "//tensorflow/contrib/py2tf:py2tf", diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 2002786999..0e6b32bb49 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -181,9 +181,10 @@ def find_files(pattern, root): matches = ['../' + x for x in find_files('*', 'external') if '.py' not in x] -so_lib_paths = [i for i in os.listdir('.') - if os.path.isdir(i) - and fnmatch.fnmatch(i, '_solib_*')] +so_lib_paths = [ + i for i in os.listdir('.') + if os.path.isdir(i) and fnmatch.fnmatch(i, '_solib_*') +] for path in so_lib_paths: matches.extend( diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index bc83dfd6cb..12d3c739cc 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -179,11 +179,11 @@ def tf_workspace(path_prefix="", 
tf_repo_name=""): tf_http_archive( name = "gemmlowp", urls = [ - "https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", - "https://github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", + "https://mirror.bazel.build/github.com/google/gemmlowp/archive/d4d1e29a62192d8defdc057b913ef36ca582ac98.zip", + "https://github.com/google/gemmlowp/archive/d4d1e29a62192d8defdc057b913ef36ca582ac98.zip", ], - sha256 = "dd2557072bde12141419cb8320a9c25e6ec41a8ae53c2ac78c076a347bb46d9d", - strip_prefix = "gemmlowp-010bb3e71a26ca1d0884a167081d092b43563996", + sha256 = "e2bee7afd3c43028f23dd0d7f85ddd8b21aaf79c572b658e56164ef502b2b9c7", + strip_prefix = "gemmlowp-d4d1e29a62192d8defdc057b913ef36ca582ac98", ) tf_http_archive( diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/jpeg.BUILD index ca2d38d687..87a23925c4 100644 --- a/third_party/jpeg/jpeg.BUILD +++ b/third_party/jpeg/jpeg.BUILD @@ -145,9 +145,9 @@ cc_library( "jpeglib.h", "jsimd.h", "jsimddct.h", - "simd/jsimd.h", "simd/jccolor-altivec.c", "simd/jcgray-altivec.c", + "simd/jcsample.h", "simd/jcsample-altivec.c", "simd/jdcolor-altivec.c", "simd/jdmerge-altivec.c", @@ -157,15 +157,15 @@ cc_library( "simd/jidctfst-altivec.c", "simd/jidctint-altivec.c", "simd/jquanti-altivec.c", - "simd/jsimd_powerpc.c", + "simd/jsimd.h", "simd/jsimd_altivec.h", - "simd/jcsample.h", + "simd/jsimd_powerpc.c", ], hdrs = [ - "simd/jdmrgext-altivec.c", # should have been named .inc - "simd/jccolext-altivec.c", # should have been named .inc - "simd/jcgryext-altivec.c", # should have been named .inc - "simd/jdcolext-altivec.c", # should have been named .inc + "simd/jccolext-altivec.c", # should have been named .inc + "simd/jcgryext-altivec.c", # should have been named .inc + "simd/jdcolext-altivec.c", # should have been named .inc + "simd/jdmrgext-altivec.c", # should have been named .inc ], copts = libjpegturbo_copts, nocopts = libjpegturbo_nocopts, @@ -545,7 +545,6 
@@ config_setting( ) config_setting( - name = "linux_ppc64le", - values = {"cpu": "ppc"}, - + name = "linux_ppc64le", + values = {"cpu": "ppc"}, ) |