author    2018-05-08 17:14:55 -0700
committer 2018-05-08 17:14:55 -0700
commit    24c9174f84be94043e58ac4536295a3d44d82678 (patch)
tree      92f6cfd82d9ad2c295ec8a45bd7df8d5b5d6ee0f
parent    c0fb9413914d983cad2ea6bb4997033a1f0dd722 (diff)
parent    14d5f219f33b1ab8e0a67b84d97204d046adb91f (diff)
Merge commit for internal changes
341 files changed, 17489 insertions, 6704 deletions
diff --git a/configure.py b/configure.py
index fe15bfc1a4..6d9aba61bb 100644
--- a/configure.py
+++ b/configure.py
@@ -845,8 +845,8 @@ def reformat_version_sequence(version_str, sequence_count):
 def set_tf_cuda_version(environ_cp):
   """Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
   ask_cuda_version = (
-      'Please specify the CUDA SDK version you want to use, '
-      'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
+      'Please specify the CUDA SDK version you want to use. '
+      '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
 
   for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
     # Configure the Cuda SDK version to use.
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 82dbd3cdbc..95b04f9058 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -8407,3 +8407,51 @@ TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
   }
   return ret;
 }
+
+void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
+                           TF_Tensor* tensor, TF_Status* status) {
+  assert(session);
+  {
+    tensorflow::mutex_lock c(session->graph->mu);
+    if (VLOG_IS_ON(1)) {
+      VLOG(1) << "Enqueuing named tensor with id " << tensor_id
+              << ", with input graph: "
+              << session->graph->graph.ToGraphDefDebug().DebugString();
+      tensorflow::Tensor internal_tensor;
+      if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
+        VLOG(1) << "Enqueu'ing tensor content: "
+                << internal_tensor.DebugString();
+      }
+    }
+  }
+
+  TF_Operation* enqueue_op = TF_GraphOperationByName(
+      session->graph,
+      tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
+  if (enqueue_op == nullptr) {
+    status->status = tensorflow::errors::Internal(
+        "Unable to find the enqueue node in the TF graph.");
+    return;
+  }
+
+  TF_Operation* placeholder_op = TF_GraphOperationByName(
+      session->graph,
+      tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
+  if (placeholder_op == nullptr) {
+    status->status = tensorflow::errors::Internal(
+        "Unable to find the placeholder node as input to enqueue in the TF "
+        "graph.");
+    return;
+  }
+
+  VLOG(1) << "Running the enqueue op";
+  TF_Output input{placeholder_op, 0};
+  TF_SessionRun(session, /*run_options*/ nullptr,
+                // input related parameters
+                /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
+                // output related parameters
+                /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
+                /*targets*/ &enqueue_op, /*ntargets*/ 1,
+                /*run_metadata*/ nullptr, status);
+  VLOG(1) << "Enqueuing is done.";
+}
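A minimal sketch of how a client might drive this enqueue/dequeue pair, assuming a session whose graph already defines the "fifo_queue_enqueue_<id>", "arg_tensor_enqueue_<id>" and "fifo_queue_dequeue_<id>" nodes (building that graph is not part of this patch; the RoundTrip helper name is illustrative only):

#include <cstdio>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

// `session` must come from a graph that already contains the queue nodes
// named after `tensor_id`; this helper only exercises the two C API calls.
void RoundTrip(TF_Session* session, int tensor_id, TF_Tensor* tensor) {
  TF_Status* status = TF_NewStatus();

  // Blocks if the queue is at capacity; `tensor` remains owned by the caller.
  TF_EnqueueNamedTensor(session, tensor_id, tensor, status);
  if (TF_GetCode(status) != TF_OK) {
    std::fprintf(stderr, "enqueue failed: %s\n", TF_Message(status));
  }

  // Blocks if the queue is empty; the caller owns (and must delete) the result.
  TF_Tensor* out = TF_DequeueNamedTensor(session, tensor_id, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_DeleteTensor(out);
  }
  TF_DeleteStatus(status);
}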
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index e6757c065f..20bdace40f 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -87,8 +87,11 @@ TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets(
     unsigned char is_mnist, TF_Status* status);
 
 // On success, dequeues a tensor from a TF-managed FifoQueue given by
-// `tensor_id`, associated with `session`. Caller must call TF_DeleteTensor()
-// over the returned tensor. If the queue is empty, this call is blocked.
+// `tensor_id`, associated with `session`. There must be a graph node named
+// "fifo_queue_dequeue_<tensor_id>", to be executed by this API call.
+
+// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is
+// empty, this call is blocked.
 //
 // Tensors are enqueued via the corresponding TF enqueue op.
 // TODO(hongm): Add support for `timeout_ms`.
@@ -96,6 +99,22 @@ TF_CAPI_EXPORT extern TF_Tensor* TF_DequeueNamedTensor(TF_Session* session,
                                                        int tensor_id,
                                                        TF_Status* status);
 
+// On success, enqueues `tensor` into a TF-managed FifoQueue given by
+// `tensor_id`, associated with `session`. There must be a graph node named
+// "fifo_queue_enqueue_<tensor_id>", to be executed by this API call. It reads
+// from a placeholder node "arg_tensor_enqueue_<tensor_id>".
+//
+// `tensor` is still owned by the caller. This call will be blocked if the queue
+// has reached its capacity, and will be unblocked when the queued tensors again
+// drop below the capacity due to dequeuing.
+//
+// Tensors are dequeued via the corresponding TF dequeue op.
+// TODO(hongm): Add support for `timeout_ms`.
+TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
+                                                 int tensor_id,
+                                                 TF_Tensor* tensor,
+                                                 TF_Status* status);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 8026076b9e..e9ed3395c4 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -130,13 +130,15 @@ class GradientTape {
     }
   }
 
-  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
+  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
+                    gtl::ArraySlice<tensorflow::DataType> dtypes);
 
   void Watch(int64 tensor_id);
 
   void RecordOperation(const string& op_type,
                        gtl::ArraySlice<TapeTensor> output_tensors,
                        gtl::ArraySlice<int64> input_tensor_id,
+                       gtl::ArraySlice<tensorflow::DataType> input_dtypes,
                        BackwardFunction* backward_function,
                        const std::function<void()>& backward_function_deleter);
 
@@ -170,12 +172,30 @@ class GradientTape {
 
 // Template instantiations here
 
+inline bool IsDtypeTrainable(DataType dtype) {
+  switch (dtype) {
+    case DT_HALF:
+    case DT_BFLOAT16:
+    case DT_FLOAT:
+    case DT_DOUBLE:
+    case DT_COMPLEX64:
+    case DT_COMPLEX128:
+    case DT_RESOURCE:
+    case DT_VARIANT:
+      return true;
+    default:
+      return false;
+  }
+}
+
 template <typename Gradient, typename BackwardFunction>
 bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
-    gtl::ArraySlice<int64> tensor_ids) {
-  for (int64 i : tensor_ids) {
-    if (tensor_tape_.find(i) != tensor_tape_.end()) {
-      return true;
+    gtl::ArraySlice<int64> tensor_ids,
+    gtl::ArraySlice<tensorflow::DataType> dtypes) {
+  CHECK_EQ(tensor_ids.size(), dtypes.size());
+  for (int i = 0; i < tensor_ids.size(); ++i) {
+    if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
+      return IsDtypeTrainable(dtypes[i]);
     }
   }
   return false;
@@ -189,9 +209,11 @@ void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
 template <typename Gradient, typename BackwardFunction>
 void GradientTape<Gradient, BackwardFunction>::RecordOperation(
     const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
-    gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
+    gtl::ArraySlice<int64> input_tensor_id,
+    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
+    BackwardFunction* backward_function,
    const std::function<void()>& backward_function_deleter) {
-  if (!ShouldRecord(input_tensor_id)) {
+  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
     backward_function_deleter();
     return;
   }
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 222e26810a..fd2cf2b67d 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -15,6 +15,7 @@ test_suite(
         ":test_graph_tfadd_with_ckpt_saver_test",
         ":test_graph_tfadd_with_ckpt_test",
         ":test_graph_tfassert_eq_test",
+        ":test_graph_tfcond_test",
         ":test_graph_tffunction_test",
         ":test_graph_tfgather_test",
         ":test_graph_tfmatmul_test",
@@ -55,6 +56,7 @@ genrule(
         "test_graph_tfadd_with_ckpt_saver.pb",
         "test_graph_tfadd_with_ckpt_saver.saver",
         "test_graph_tfassert_eq.pb",
+        "test_graph_tfcond.pb",
         "test_graph_tffunction.pb",
         "test_graph_tfgather.pb",
         "test_graph_tfmatmul.pb",
@@ -119,6 +121,17 @@ tf_library(
 )
 
 tf_library(
+    name = "test_graph_tfcond",
+    testonly = 1,
+    config = "test_graph_tfcond.config.pbtxt",
+    cpp_class = "CondComp",
+    graph = "test_graph_tfcond.pb",
+    tags = [
+        "manual",
+    ],
+)
+
+tf_library(
     name = "test_graph_tffunction",
     testonly = 1,
     config = "test_graph_tffunction.config.pbtxt",
@@ -194,6 +207,7 @@ tf_cc_test(
         ":test_graph_tfadd_with_ckpt",
         ":test_graph_tfadd_with_ckpt_saver",
         ":test_graph_tfassert_eq",
+        ":test_graph_tfcond",
         ":test_graph_tffunction",
         ":test_graph_tfgather",
         ":test_graph_tfmatmul",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 67767f55da..9ec7df163b 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -78,6 +78,22 @@ def tfadd_with_ckpt_saver(out_dir):
     f.write(saver.as_saver_def().SerializeToString())
 
 
+def tfassert_eq(_):
+  x = array_ops.placeholder(dtypes.int32, name='x_hold')
+  y = array_ops.placeholder(dtypes.int32, name='y_hold')
+  control_flow_ops.Assert(
+      math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
+  math_ops.add(x, math_ops.negative(y), name='x_y_diff')
+
+
+def tfcond(_):
+  p = array_ops.placeholder(dtypes.bool, name='p_hold')
+  x = array_ops.placeholder(dtypes.int32, name='x_hold')
+  y = array_ops.placeholder(dtypes.int32, name='y_hold')
+  z = control_flow_ops.cond(p, lambda: x, lambda: y)
+  array_ops.identity(z, name='result')
+
+
 def tfgather(_):
   params = array_ops.placeholder(dtypes.float32, name='params')
   indices = array_ops.placeholder(dtypes.int32, name='indices')
@@ -126,14 +142,6 @@ def tfsplits(_):
   array_ops.identity(y, name='result')
 
 
-def tfassert_eq(_):
-  x = array_ops.placeholder(dtypes.int32, name='x_hold')
-  y = array_ops.placeholder(dtypes.int32, name='y_hold')
-  control_flow_ops.Assert(
-      math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
-  math_ops.add(x, math_ops.negative(y), name='x_y_diff')
-
-
 def write_graph(build_graph, out_dir):
   """Build a graph using build_graph and write it out."""
   g = ops.Graph()
@@ -148,12 +156,13 @@ def main(_):
   write_graph(tfadd, FLAGS.out_dir)
   write_graph(tfadd_with_ckpt, FLAGS.out_dir)
   write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
+  write_graph(tfassert_eq, FLAGS.out_dir)
+  write_graph(tfcond, FLAGS.out_dir)
+  write_graph(tffunction, FLAGS.out_dir)
   write_graph(tfgather, FLAGS.out_dir)
   write_graph(tfmatmul, FLAGS.out_dir)
   write_graph(tfmatmulandadd, FLAGS.out_dir)
-  write_graph(tffunction, FLAGS.out_dir)
   write_graph(tfsplits, FLAGS.out_dir)
-  write_graph(tfassert_eq, FLAGS.out_dir)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt
new file mode 100644
index 0000000000..94a01ad4ab
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt
@@ -0,0 +1,20 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+  id { node_name: "p_hold" }
+  shape {}
+}
+feed {
+  id { node_name: "x_hold" }
+  shape {
+    dim { size: 1 }
+  }
+}
+feed {
+  id { node_name: "y_hold" }
+  shape {
+    dim { size: 1 }
+  }
+}
+fetch {
+  id { node_name: "result" }
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 27ba42b31f..309a991fc1 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfcond.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
@@ -150,6 +151,31 @@ TEST(TFCompileTest, AddWithCkptSaver) {
   EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
 }
 
+TEST(TFCompileTest, Cond) {
+  CondComp cond;
+  EXPECT_EQ(cond.arg0_data(), cond.args()[0]);
+  EXPECT_EQ(cond.arg1_data(), cond.args()[1]);
+  EXPECT_EQ(cond.arg2_data(), cond.args()[2]);
+  cond.arg1() = 10;
+  cond.arg2() = 20;
+  {
+    cond.arg0() = true;
+    const int32 expected_result = cond.arg1();
+    EXPECT_TRUE(cond.Run());
+    EXPECT_EQ(cond.result0(), expected_result);
+    EXPECT_EQ(cond.result0_data()[0], expected_result);
+    EXPECT_EQ(cond.result0_data(), cond.results()[0]);
+  }
+  {
+    cond.arg0() = false;
+    const int32 expected_result = cond.arg2();
+    EXPECT_TRUE(cond.Run());
+    EXPECT_EQ(cond.result0(), expected_result);
+    EXPECT_EQ(cond.result0_data()[0], expected_result);
+    EXPECT_EQ(cond.result0_data(), cond.results()[0]);
+  }
+}
+
 TEST(TFCompileTest, Gather) {
   GatherComp gather;
   EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 07136d6a74..a6b3ce394c 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -261,6 +261,7 @@ cc_library(
     name = "create_xla_launch_op",
     srcs = [
         "create_xla_launch_op.cc",
+        "create_xla_launch_op.h",
     ],
     deps = [
         ":common",
@@ -270,6 +271,29 @@ cc_library(
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/memory",
     ],
+    alwayslink = 1,
+)
+
+tf_cc_test(
+    name = "create_xla_launch_op_test",
+    srcs = [
+        "create_xla_launch_op.h",
+        "create_xla_launch_op_test.cc",
+    ],
+    deps = [
+        ":create_xla_launch_op",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:session_options",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "@com_google_absl//absl/memory",
     ],
 )
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 18d901323f..f35e916eb9 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -25,78 +27,189 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
-// Givens a NodeDef 'ndef' and the function library runtime 'flr', if
-// 'ndef' is a call to a compilable function defined in 'flr', returns OK
-// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
-// node. Otherwise, returns a non-OK.
+// Utility which searches for values in a sorted list by scanning over it once.
+// No matter how many times ScanForValue is called, the list is scanned at most
+// once. However, if a call to ScanForValue skips over a value, that value is
+// not revisited in future calls to ScanForValue, so callers must take
+// care to order their calls.
 //
-// This routine is here so that FunctionLibraryRuntime can jit a
-// specific function call as requested.
-Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
-                         std::unique_ptr<OpKernel>* kernel) {
-  bool xla_compile = false;
-  if (!flr->GetFunctionLibraryDefinition()
-           ->GetAttr(ndef, kXlaCompileAttr, &xla_compile)
-           .ok() ||
-      !xla_compile) {
-    // Not marked as _XlaCompile=true.
-    return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
+// Useful for merging multiple sorted lists in O(n) time.
+class SinglePassSearch {
+ public:
+  // Creates a SinglePassSearch object that can be used to search in `values`.
+  // Does not take ownership of `values`. `values` must outlive this.
+  // `values` must be sorted.
+  explicit SinglePassSearch(const std::vector<int>* values)
+      : current_index_(0), values_(values) {}
+
+  // Scans forward in the vector looking for "value", updating the internal
+  // position in to the vector.
+  // Returns true iff the vector contains the given value at or after current
+  // position.
+  // Not thread-safe.
+  bool ScanForValue(int value) {
+    while (current_index_ < values_->size() &&
+           (*values_)[current_index_] <= value) {
+      if ((*values_)[current_index_] == value) {
+        current_index_++;
+        return true;
+      }
+      current_index_++;
+    }
+    return false;
   }
-  // Make sure that kernels have been registered on the JIT device.
-  XlaOpRegistry::RegisterCompilationKernels();
-  if (!IsCompilable(flr, ndef)) {
-    // ndef is calling a function that XLA can't compile.
-    return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
+
+ private:
+  int current_index_;
+  const std::vector<int>* values_;
+};
+
+Status CompilationRequested(const FunctionLibraryRuntime& flr,
+                            const NodeDef& node_def) {
+  bool xla_compile = false;
+  // Check if op is marked _XlaCompile=true.
+  Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
+      node_def, kXlaCompileAttr, &xla_compile);
+  if (!status.ok() || !xla_compile) {
+    if (VLOG_IS_ON(3)) {
+      if (!status.ok()) {
+        VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
+                << node_def.op() << ". status=" << status.ToString();
+      } else {
+        VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
+      }
+    }
+    return Status(error::INVALID_ARGUMENT, "");
   }
+  return Status::OK();
+}
+
+// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
+// runtime, returns this function's body in `fbody` as well as the indices
+// of its constant and resource arguments.
+// `fbody` is owned by `flr`.
+// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
+// They are sorted in ascending order on this function's return.
+Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
+                                       const NodeDef& node_def,
+                                       const FunctionBody** fbody,
+                                       std::vector<int>* constant_arg_indices,
+                                       std::vector<int>* resource_arg_indices) {
   FunctionLibraryRuntime::Handle handle;
-  // If ndef is not instantiable, e.g., the function does not exist,
+  // If node_def is not instantiable, e.g., the function does not exist,
   // simply bail out.
   TF_RETURN_IF_ERROR(
-      flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
-  const FunctionBody* fbody = flr->GetFunctionBody(handle);
-  CHECK(fbody);  // Can't be nullptr since we just instantiated it.
-  std::vector<bool> const_args(fbody->arg_types.size());
+      flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
+  *fbody = flr->GetFunctionBody(handle);
+  CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
+  const DataTypeVector& arg_types = (*fbody)->arg_types;
+  std::vector<bool> const_args(arg_types.size());
   // If we can't analyze the const args. Bail out.
-  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args));
+  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args));
 
   for (int i = 0; i < const_args.size(); ++i) {
     if (const_args[i]) {
-      // There is a const arg. Bail out.
-      return errors::InvalidArgument("Const arg: ", i, " in ",
-                                     DebugString(fbody->fdef));
+      constant_arg_indices->push_back(i);
+    }
+  }
+
+  // There can be hundreds of resource variables. Reserve the space for them.
+  // We don't reserve for constants above as they are usually few.
+  resource_arg_indices->reserve(arg_types.size());
+  for (int i = 0; i < arg_types.size(); ++i) {
+    if (arg_types[i] == DT_RESOURCE) {
+      resource_arg_indices->push_back(i);
     }
   }
 
-  NodeDef launch_def;
-  launch_def.set_name(ndef.name());
-  launch_def.set_op("_XlaLaunch");
-  launch_def.set_device(flr->device()->name());
-  AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
-  AddNodeAttr("Nresources", 0, &launch_def);
-  AddNodeAttr("Targs", fbody->arg_types, &launch_def);
-  AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
-  NameAttrList func;
-  func.set_name(ndef.op());
-  *(func.mutable_attr()) = ndef.attr();
-  AddNodeAttr("function", func, &launch_def);
-
-  // TODO(b/32387911): Handles the host memory types across function
-  // calls properly. For now, we assume all inputs and outputs are on
-  // the device memory.
+  return Status::OK();
+}
+
+}  // namespace
+
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
                         std::unique_ptr<OpKernel>* kernel) {
+  TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));
+
+  VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString();
+
+  // Make sure that kernels have been registered on the JIT device.
+  XlaOpRegistry::RegisterCompilationKernels();
+  if (!IsCompilable(flr, node_def)) {
+    // node_def is calling a function that XLA can't compile.
+    return errors::InvalidArgument("Not compilable: ",
+                                   node_def.ShortDebugString());
+  }
+
+  // Get function body, constant args, and resource args.
+  const FunctionBody* fbody = nullptr;
+  std::vector<int> constant_arg_indices;
+  std::vector<int> resource_arg_indices;
+  TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
+      flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
+
+  // Set input and output memory types.
   MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
+  // These indices are used only for optimization purposes. They allow us
+  // to loop over constant_arg_indices and resource_arg_indices only once
+  // while iterating over all the function arguments checking if it is a
+  // resource or a constant.
+  // The reason we optimized this code is because functions can have a lot of
+  // captured arguments. For example, the backward pass of ResNet50 takes in all
+  // 214 variables and a similar number of activations.
+  SinglePassSearch constants_search(&constant_arg_indices);
+  SinglePassSearch resources_search(&resource_arg_indices);
+  for (int i = 0; i < fbody->arg_types.size(); ++i) {
+    if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
+      // Compile-time constants and resource handles are expected to be in
+      // host memory.
+      input_memory_types[i] = HOST_MEMORY;
+    }
+  }
+  // One might wonder, about the case where a compile-time constant argument
+  // (which must be in host memory) is also used as an input into an op,
+  // e.g. Add, that expects its inputs in device memory. Here is how it
+  // works now.
+  // First, what do we mean by "op expects an input in XYZ memory"?
+  // There are two types of "ops" here: the tf2xla kernel and the HLO
+  // computation it builds. The tf2xla kernel needs to retrieve the actual
+  // numeric value of the compile-time constant tensors, so it really expects
+  // them to be on in host memory. However, for other inputs, it refers to them
+  // using xla::ComputationDataHandle, which is just a symbolic handle that
+  // xla::ComputationBuilder assigns. How does this handle gets assigned for
+  // constant arguments? Even constant arguments get an _Arg node in the graph
+  // instatiated for Function compilation. The tf2xla kernel for constant _Arg
+  // nodes takes the constant value, converts it to XlaLiteral, and feeds it
+  // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
+  // constant XlaLiteral is included in the HLO graph, and subsequently, in
+  // the actual executable, which is copied to the device before being
+  // executed. Thus, when this executable runs, the constant is available in
+  // device memory.
+
+  // XlaLaunch kernel keeps all outputs (including constants, which it copies),
+  // in device memory
   MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
 
+  // Create the kernel.
+  NameAttrList function;
+  function.set_name(node_def.op());
+  *(function.mutable_attr()) = node_def.attr();
+
   Device* dev = flr->device();
   Status s;
   OpKernelConstruction construction(
       DeviceType(dev->device_type()), dev,
-      dev->GetAllocator(AllocatorAttributes()), &launch_def,
+      dev->GetAllocator(AllocatorAttributes()), &node_def,
       &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
       fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
-  kernel->reset(new XlaLocalLaunchOp(&construction));
+
+  *kernel = absl::make_unique<XlaLocalLaunchBase>(
+      &construction, constant_arg_indices, resource_arg_indices, function);
   return s;
 }
 
+namespace {
+
 bool RegisterLaunchOpCreator() {
   RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
   return true;
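SinglePassSearch amortizes all membership tests over a single scan, so queries must arrive in ascending order. A standalone sketch of how CreateXlaLaunchOp uses two of them to mark constant and resource arguments as HOST_MEMORY (the class body is adapted from the patch above; main() and its index values are illustrative):

#include <cstdio>
#include <vector>

// Adapted from create_xla_launch_op.cc: scans a sorted list at most once
// across all ScanForValue calls; skipped values are never revisited.
class SinglePassSearch {
 public:
  explicit SinglePassSearch(const std::vector<int>* values)
      : current_index_(0), values_(values) {}

  bool ScanForValue(int value) {
    while (current_index_ < values_->size() &&
           (*values_)[current_index_] <= value) {
      if ((*values_)[current_index_] == value) {
        current_index_++;
        return true;
      }
      current_index_++;
    }
    return false;
  }

 private:
  size_t current_index_;
  const std::vector<int>* values_;
};

int main() {
  // Six function arguments; 0 and 3 are compile-time constants, 4 and 5 are
  // resources. One ascending pass classifies every argument in O(n) total.
  std::vector<int> constants = {0, 3};
  std::vector<int> resources = {4, 5};
  SinglePassSearch constants_search(&constants);
  SinglePassSearch resources_search(&resources);
  for (int i = 0; i < 6; ++i) {
    bool host = resources_search.ScanForValue(i) ||
                constants_search.ScanForValue(i);
    std::printf("arg %d -> %s\n", i, host ? "HOST_MEMORY" : "DEVICE_MEMORY");
  }
}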
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h
new file mode 100644
index 0000000000..98a22e3515
--- /dev/null
+++ b/tensorflow/compiler/jit/create_xla_launch_op.h
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class FunctionLibraryRuntime;
+class OpKernel;
+
+// Given a NodeDef 'node_def' and the function library runtime 'flr', if
+// 'node_def' is a call to a compilable function defined in 'flr', returns OK
+// and fills in 'kernel' with a XlaLaunchOp kernel which computes the
+// node. Otherwise, returns a non-OK.
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
+                         std::unique_ptr<OpKernel>* kernel);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
new file mode 100644
index 0000000000..bcd5e75c7e
--- /dev/null
+++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
@@ -0,0 +1,145 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/create_xla_launch_op.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+NodeDef ToNodeDef(const string& text) {
+  NodeDef node_def;
+  EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
+  return node_def;
+}
+
+// Create a FunctionDef that takes one resource and one regular param
+FunctionDef XTimesY() {
+  return FunctionDefHelper::Define(
+      // Name
+      "XTimesY",
+      // Args
+      {"x: float", "y: resource"},
+      // Return values
+      {"z: float"},
+      // Attr def
+      {},
+      // Nodes
+      {
+          {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}},
+          {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}},
+      });
+}
+
+class CreateXlaLaunchOpTest : public ::testing::Test {
+ protected:
+  void Init(const std::vector<FunctionDef>& flib) {
+    SessionOptions options;
+    auto* device_count = options.config.mutable_device_count();
+    device_count->insert({"CPU", 1});
+    TF_CHECK_OK(DeviceFactory::AddDevices(
+        options, "/job:localhost/replica:0/task:0", &devices_));
+
+    FunctionDefLibrary proto;
+    for (const auto& fdef : flib) {
+      *(proto.add_function()) = fdef;
+    }
+    lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
+        OpRegistry::Global(), proto);
+    OptimizerOptions opts;
+    device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
+    pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
+        device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
+        opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
+    flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
+  }
+
+  FunctionLibraryRuntime* flr_;
+  std::vector<Device*> devices_;
+  std::unique_ptr<DeviceMgr> device_mgr_;
+  std::unique_ptr<FunctionLibraryDefinition> lib_def_;
+  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+
+  std::unique_ptr<OpKernel> kernel_;
+};
+
+AttrValue BoolAttr(bool b) {
+  AttrValue v;
+  v.set_b(b);
+  return v;
+}
+
+TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) {
+  FunctionDef fdef = XTimesY();
+  (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true);
+  Init({fdef});
+
+  Status status = CreateXlaLaunchOp(
+      flr_, ToNodeDef(R"pb(
+        name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b'
+      )pb"), &kernel_);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+
+  EXPECT_EQ("XTimesY", kernel_->name());
+  EXPECT_EQ("XTimesY", kernel_->type_string());
+
+  EXPECT_EQ(2, kernel_->num_inputs());
+  EXPECT_EQ(DT_FLOAT, kernel_->input_type(0));
+  EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1));
+  EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]);
+  EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]);
+
+  EXPECT_EQ(1, kernel_->num_outputs());
+  EXPECT_EQ(DT_FLOAT, kernel_->output_type(0));
+  EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]);
+}
+
+TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) {
+  FunctionDef fdef = XTimesY();
+  Init({fdef});
+
+  Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
+      name: 'XTimesY'
+      op: 'XTimesY'
+      input: 'a'
+      input: 'b'
+  )proto"), &kernel_);
+  EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
+}
+
+TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) {
+  FunctionDef fdef = XTimesY();
+  (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false);
+  Init({fdef});
+
+  Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto(
+      name: 'XTimesY'
+      op: 'XTimesY'
+      input: 'a'
+      input: 'b'
+  )proto"), &kernel_);
+  EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 049d170fa4..86a9fd3b8e 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -39,15 +39,15 @@ limitations under the License.
 
 namespace tensorflow {
 
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
-    : OpKernel(ctx), device_type_(ctx->device_type()) {
-  const NameAttrList* func;
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
-  function_ = *func;
-  DataTypeVector constant_types;
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
-  num_constant_args_ = constant_types.size();
-  OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
+XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
+                                       const std::vector<int>& constants,
+                                       const std::vector<int>& resources,
+                                       const NameAttrList& function)
+    : OpKernel(ctx),
+      constants_(constants),
+      resources_(resources),
+      device_type_(ctx->device_type()),
+      function_(function) {
   if (device_type_ == DeviceType(DEVICE_CPU)) {
     platform_id_ = se::host::kHostPlatformId;
   } else if (device_type_ == DeviceType(DEVICE_GPU)) {
@@ -57,8 +57,8 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
   }
 }
 
-Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
-                                               XlaCompilationCache** cache) {
+Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
+                                                 XlaCompilationCache** cache) {
   const XlaDevice::Metadata* metadata;
   Status s = XlaDevice::GetMetadata(ctx, &metadata);
   if (s.ok()) {
@@ -90,8 +90,8 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
   return Status::OK();
 }
 
-void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
-  VLOG(1) << "XlaLocalLaunchOp::Compute "
+void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "XlaLocalLaunchOpBase::Compute "
           << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
   // We store information about the JIT-compiled XLA computation
   // in the ResourceMgr.
@@ -124,7 +124,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   }
 
   std::map<int, OptionalTensor> variables =
-      SnapshotResourceVariables(ctx, num_resource_args_);
+      SnapshotResourceVariables(ctx, resources_);
 
   xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
 
@@ -161,7 +161,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   xla::LocalExecutable* executable;
 
   std::map<int, Tensor> constant_args;
-  for (int i = 0; i < num_constant_args_; ++i) {
+  for (int i : constants_) {
     constant_args.insert({i, ctx->input(i)});
   }
   OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
@@ -170,8 +170,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
   VLOG(1) << "Executing XLA Computation...";
 
-  XlaComputationLaunchContext launch_context(
-      num_resource_args_, client, xla_allocator, allocate_xla_tensors);
+  XlaComputationLaunchContext launch_context(client, xla_allocator,
+                                             allocate_xla_tensors);
   launch_context.PopulateInputs(ctx, kernel, variables);
 
   // Execute the computation.
@@ -194,6 +194,62 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   VLOG(1) << "Done";
 }
 
+namespace {
+
+// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
+// in error case, it returns RET instead of void.
+#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
+  do {                                                      \
+    ::tensorflow::Status _s(__VA_ARGS__);                   \
+    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
+      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+      return RET;                                           \
+    }                                                       \
+  } while (0)
+
+// Helper static functions to construct parameters for
+// XlaLocalLaunchBase constructor from OpKernelConstruction.
+std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
+  DataTypeVector constant_types;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Tconstants", &constant_types));
+  std::vector<int> constants(constant_types.size());
+  std::iota(constants.begin(), constants.end(), 0);
+  return constants;
+}
+
+std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
+  DataTypeVector constant_types;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Tconstants", &constant_types));
+
+  DataTypeVector arg_types;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Targs", &arg_types));
+
+  int num_resources;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Nresources", &num_resources));
+
+  std::vector<int> resources(num_resources);
+  std::iota(resources.begin(), resources.end(),
+            constant_types.size() + arg_types.size());
+  return resources;
+}
+
+NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
+  const NameAttrList* func;
+  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
+  return *func;
+}
+
+#undef OP_REQUIRES_OK_RETURN
+}  // namespace
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
+                         FunctionAttr(ctx)) {}
+
 XlaLocalLaunchOp::~XlaLocalLaunchOp() {
   VLOG(1) << "XlaLocalLaunchOp destroyed";
 }
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
index 8f8e646f0f..8dfc4b382d 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h
@@ -26,6 +26,41 @@ limitations under the License.
 
 namespace tensorflow {
 
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes vectors of constant and resource arguments explicitly.
+// It does not have corresponding OpDef because it is never present
+// in the GraphDef.
+// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+  XlaLocalLaunchBase(OpKernelConstruction* ctx,
+                     const std::vector<int>& constants,
+                     const std::vector<int>& resources,
+                     const NameAttrList& function);
+  XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+  XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+  ~XlaLocalLaunchBase() override = default;
+
+  void Compute(OpKernelContext* ctx) override;
+
+ protected:
+  // Builds a XlaCompilationCache class suitable for the current device.
+  Status BuildCompilationCache(OpKernelContext* ctx,
+                               XlaCompilationCache** cache);
+
+  // Indexes of compile-time constant inputs
+  std::vector<int> constants_;
+  // Indexes of resource inputs
+  std::vector<int> resources_;
+
+  DeviceType device_type_;
+  NameAttrList function_;
+  se::Platform::Id platform_id_;
+};
+
 // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
 // which will be compiled and executed using XLA.  The XlaLocalLaunchOp is
 // responsible for handling interactions with the TensorFlow executor.
@@ -35,26 +70,12 @@ namespace tensorflow {
 // XlaLocalLaunchOp uses xla::LocalClient::Compile() and
 // xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
 // memory.
-class XlaLocalLaunchOp : public OpKernel {
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
  public:
   explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
   ~XlaLocalLaunchOp() override;
 
-  void Compute(OpKernelContext* ctx) override;
-
  private:
-  // Builds a XlaCompilationCache class suitable for the current device.
-  Status BuildCompilationCache(OpKernelContext* ctx,
-                               XlaCompilationCache** compiler);
-
-  DeviceType device_type_;
-  NameAttrList function_;
-  int num_constant_args_;
-  // Number of resource variable arguments.
-  int num_resource_args_;
-
-  se::Platform::Id platform_id_;
-
   TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
 };
 
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 60458f6f33..6b83cf67ff 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -48,13 +48,12 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
                                  const XlaCompiler::CompilationResult* result,
                                  xla::LocalExecutable* executable) {
   std::map<int, OptionalTensor> variables = GetVariables(ctx);
-  int64 num_resource_args = variables.size();
 
   xla::LocalClient* client = metadata.client();
 
   // Builds an XLA allocator for the device.
   XlaComputationLaunchContext launch_context(
-      num_resource_args, client, client->backend().memory_allocator(), true);
+      client, client->backend().memory_allocator(), true);
 
   launch_context.PopulateInputs(ctx, result, variables);
 
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 33e53612b9..0223f97a03 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -38,14 +38,13 @@ using xla::ScopedShapedBuffer;
 using xla::ShapedBuffer;
 }  // anonymous namespace
 
-std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
-                                                        int num_variables) {
+std::map<int, OptionalTensor> SnapshotResourceVariables(
+    OpKernelContext* ctx, const std::vector<int>& variables) {
   std::map<int, OptionalTensor> snapshot;
-  int first_variable = ctx->num_inputs() - num_variables;
-  for (int i = 0; i < num_variables; ++i) {
+  for (int i : variables) {
     Var* variable = nullptr;
-    ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
-    OptionalTensor& tensor = snapshot[first_variable + i];
+    ResourceHandle handle = HandleFromInput(ctx, i);
+    OptionalTensor& tensor = snapshot[i];
     if (LookupResource(ctx, handle, &variable).ok()) {
       tf_shared_lock lock(*variable->mu());
       tensor.name = handle.name();
@@ -112,10 +111,9 @@ ScopedShapedBuffer ExtractSubShapedBuffer(
 using internal::ExtractSubShapedBuffer;
 
 XlaComputationLaunchContext::XlaComputationLaunchContext(
-    int64 num_resource_args, xla::LocalClient* client,
-    xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors)
-    : num_resource_args_(num_resource_args),
-      client_(client),
+    xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
+    bool allocate_xla_tensors)
+    : client_(client),
       xla_allocator_(xla_allocator),
       allocate_xla_tensors_(allocate_xla_tensors) {}
 
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 38291b0bd4..a2431253f8 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -31,15 +31,17 @@ limitations under the License.
 namespace tensorflow {
 class XlaAllocator;
 
-// Takes a snapshot of the values of resource variable arguments, which are
-// the last `num_variables` arguments. We snapshot tensors that back
+// Takes a snapshot of the values of resource variable arguments, whose
+// indices are specified in `variables` argument. We snapshot tensors that back
 // resource variables since concurrent updates may modify the shape, and it is
 // important that the shapes used for compilation match the true shapes of the
 // buffers.
 //
-// Returns a map of TensorFlow argument index to resource variable.
-std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
-                                                        int num_variables);
+// Returns a map of TensorFlow argument index to resource variable. If a
+// resource variable is not initialized, the corresponding OptionalTensor
+// will have its `present` field set to false.
+std::map<int, OptionalTensor> SnapshotResourceVariables(
+    OpKernelContext* ctx, const std::vector<int>& variables);
 
 // Adapter class that wraps a Tensorflow allocator as an XLA allocator.
 // Assumes that the Tensorflow allocator permits asynchronous deallocation:
@@ -72,7 +74,7 @@ class XlaComputationLaunchContext {
   // Create a new launch context. 'allocate_xla_tensors' is true if allocated
   // output tensors and variables are always XlaTensors. If false they are
   // assumed to be "normal" device pointers.
-  XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client,
+  XlaComputationLaunchContext(xla::LocalClient* client,
                               xla::DeviceMemoryAllocator* xla_allocator,
                               bool allocate_xla_tensors);
 
@@ -92,7 +94,6 @@ class XlaComputationLaunchContext {
   const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
 
  private:
-  int64 num_resource_args_;
   xla::LocalClient* client_;
   xla::DeviceMemoryAllocator* xla_allocator_;
   bool allocate_xla_tensors_;
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index 922a918973..6b29c82ec1 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -54,7 +54,7 @@ class XlaTensor {
   // Some Tensors can have complex on-device shapes, including tuple shapes. To
   // manage the memory for these tensors a ShapedBuffer may be required.
 
-  // Return true if this TensorInfo contains a ShapedBuffer.
+  // Return true if this XlaTensor contains a ShapedBuffer.
   bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
   // Return the contained ShapedBuffer.
   // REQUIRES: has_shaped_buffer()
@@ -62,7 +62,7 @@ class XlaTensor {
     CHECK(has_shaped_buffer());
     return *shaped_buffer_;
   }
-  // Mutates the TensorInfo to set the ShapedBuffer.
+  // Mutates the XlaTensor to set the ShapedBuffer.
   void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
     shaped_buffer_ =
         xla::MakeUnique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
@@ -72,7 +72,7 @@ class XlaTensor {
   // in on-demand mode to avoid re-copying values from the device if we know the
   // host value already.
 
-  // Return true if this TensorInfo contains a host tensor.
+  // Return true if this XlaTensor contains a host tensor.
   bool has_host_tensor() const { return host_tensor_ != nullptr; }
   // Return the contained host tensor.
   // REQUIRES: has_host_tensor()
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index a94b298f87..9791792f29 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -300,6 +300,10 @@ tf_xla_py_test(
     name = "extract_image_patches_op_test",
     size = "small",
     srcs = ["extract_image_patches_op_test.py"],
+    tags = [
+        "manual",
+        "notap",
+    ],
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
@@ -323,7 +327,11 @@ tf_xla_py_test(
         ":xla_test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:layers",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:nn",
         "//tensorflow/python:platform_test",
+        "//tensorflow/python/eager:function",
     ],
 )
 
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index bdd0185dfe..5ab1585f8c 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -24,10 +24,16 @@ from tensorflow.compiler.tests.xla_test import XLATestCase
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.layers import convolutional
+from tensorflow.python.layers import pooling
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import googletest
 
@@ -43,7 +49,7 @@ class EagerTest(XLATestCase):
 
   def testExecuteListOutputLen0(self):
     with self.test_scope():
-      empty = constant_op.constant([], dtype=dtypes.int32)
+      empty = constant_op.constant([], dtype=dtypes.float32)
       result = array_ops.unstack(empty, 0)
       self.assertTrue(isinstance(result, list))
       self.assertEqual(0, len(result))
@@ -51,7 +57,7 @@ class EagerTest(XLATestCase):
   def testExecuteListOutputLen1(self):
     with self.test_scope():
       split_dim = constant_op.constant(1)
-      value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
       result = array_ops.split(value, 1, axis=split_dim)
       self.assertTrue(isinstance(result, list))
       self.assertEqual(1, len(result))
@@ -60,7 +66,7 @@ class EagerTest(XLATestCase):
   def testExecuteListOutputLen3(self):
     with self.test_scope():
       split_dim = constant_op.constant(1)
-      value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
       result = array_ops.split(value, 3, axis=split_dim)
       self.assertTrue(isinstance(result, list))
       self.assertEqual(3, len(result))
@@ -131,7 +137,105 @@ class EagerTest(XLATestCase):
     self.assertEqual(2., grads[0][0].numpy())
 
 
-if __name__ == "__main__":
+class EagerFunctionTest(XLATestCase):
+
+  def testBasic(self):
+    with self.test_scope():
+      matmul = function.defun(math_ops.matmul, compiled=True)
+      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+      sq = matmul(t, t, transpose_a=True)
+      self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
+
+  def testConv(self):
+    if 'GPU' in self.device:
+      # TODO(b/32333178)
+      self.skipTest('Current implementation of RandomStandardNormal kernel '
+                    'is very slow on GPU, and has been blacklisted.')
+    with self.test_scope():
+      data_format = 'channels_last'
+      conv = convolutional.Conv2D(
+          filters=1, kernel_size=2, padding='VALID',
+          data_format=data_format, activation=nn_ops.relu,
+          kernel_initializer=init_ops.ones_initializer(),
+          bias_initializer=init_ops.zeros_initializer())
+      pool = pooling.MaxPooling2D(2, 2, data_format=data_format)
+
+      def model(x):
+        x = conv(x)
+        return pool(x)
+      model = function.defun(model, compiled=True)
+
+      x = array_ops.ones([1, 4, 4, 1])
+      y = model(x)
+      self.assertAllEqual(y.numpy(), [[[[4.]]]])
+
+  def testReadVariable(self):
+    with self.test_scope():
+      v = resource_variable_ops.ResourceVariable(1.0)
+
+      @function.defun(compiled=True)
+      def f():
+        return v.read_value()
+
+      var = f()
+      self.assertEqual(1.0, var.numpy())
+
+  def testUpdateVariable(self):
+    with self.test_scope():
+      v = resource_variable_ops.ResourceVariable(1.0)
+
+      def f(v):
+        v.assign_add(1.0)
+        return v
+
+      f = function.defun(f, compiled=True)
+
+      var = f(v)
+      self.assertEqual(2.0, var.numpy())
+
+  def testAllArgumentKinds(self):
+    """Test a complex function that takes different argument kinds.
+
+    tf2xla machinery that translates, compiles, and runs defuns
+    classifies arguments into: compile-time constants, regular tensors,
+    and resources. This test creates a function with a mix of all these
+    kinds. Moreover, the order of function arguments is intentionally mixed up.
+
+    This also tests the case when the same argument is a compile-time constant
+    as well as used in an operation that normally expects its inputs to be
+    in device memory - addition in this case.
+    """
+    with self.test_scope():
+      def foo(c1, r1, v1, c2, v2, r2):
+        # c1 and c2 are compile-time constants
+        # r1 and r2 are regular tensors
+        # v1 and v2 are resource variables
+        a = c1 + r1
+        b = math_ops.cast(c2, dtypes.float32) + v2
+        c = array_ops.slice(v1, c1, c2)
+        d = r2 * v2
+        return a, b, c, d
+
+      foo = function.defun(foo, compiled=True)
+
+      c1 = [0, 0]
+      c2 = array_ops.ones([2], dtype=dtypes.int32)
+
+      r1 = array_ops.ones([2])
+      r2 = [[2., 2.], [3., 3.]]
+
+      v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
+      v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])
+
+      a, b, c, d = foo(c1, r1, v1, c2, v2, r2)
+
+      self.assertAllEqual([1, 1], a.numpy())
+      self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
+      self.assertAllEqual([[1.]], c.numpy())
+      self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())
+
+
+if __name__ == '__main__':
   ops.enable_eager_execution(
       config=config_pb2.ConfigProto(log_device_placement=True))
   googletest.main()
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 4336ebdbd1..b6f8390a45 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -86,6 +86,15 @@ class StatelessRandomOpsTest(XLATestCase):
         # seed were not fixed.
         self.assertTrue(self._chi_squared(y, 10) < 16.92)
 
+  def testRandomNormalIsFinite(self):
+    with self.test_session() as sess, self.test_scope():
+      for dtype in self._random_types():
+        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+        x = stateless.stateless_random_uniform(
+            shape=[10000], seed=seed_t, dtype=dtype)
+        y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
+        self.assertTrue(np.all(np.isfinite(y)))
+
   def _normal_cdf(self, x):
     """Cumulative distribution function for a standard normal distribution."""
     return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2))
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index 6340c22518..a99d4ddc7c 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -255,7 +255,8 @@ class StatelessRandomNormalOp : public XlaOpKernel {
                     seed_shape.DebugString()));
     xla::XlaOp seed = ctx->Input(1);
     xla::XlaBuilder* builder = ctx->builder();
-    auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0);
+    auto uniform =
+        RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0);
     // Convert uniform distribution to normal distribution by computing
     // sqrt(2) * erfinv(x)
    auto normal = builder->Mul(builder->ConstantR0<float>(std::sqrt(2.0)),
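The fix above narrows the uniform range from [-1, 1) to [nextafter(-1.0f, 0.0f), 1): the subsequent sqrt(2) * erfinv(x) maps x = -1 to negative infinity, so excluding that exact endpoint keeps every generated normal sample finite (which is what testRandomNormalIsFinite checks). A small standalone demonstration of the new lower bound:

#include <cmath>
#include <cstdio>

int main() {
  // std::nextafter(-1.0f, 0.0f) is the smallest representable float
  // strictly greater than -1.0f, so the uniform draw can never hit the
  // singular point of erfinv.
  float lo = std::nextafter(-1.0f, 0.0f);
  std::printf("new lower bound: %.9g (strictly above -1: %d)\n",
              lo, lo > -1.0f);
}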
#include "tensorflow/compiler/xla/python/local_computation_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/default/thread_annotations.h" @@ -248,7 +249,7 @@ LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( return new LocalShapedBuffer(std::move(result_buffer)); } -LocalComputation::LocalComputation(Computation computation) +LocalComputation::LocalComputation(XlaComputation computation) : computation_(std::move(computation)) {} StatusOr<CompiledLocalComputation*> LocalComputation::Compile( @@ -271,7 +272,7 @@ StatusOr<CompiledLocalComputation*> LocalComputation::Compile( return new CompiledLocalComputation(std::move(local_executable)); } -const Computation& LocalComputation::computation() const { +const XlaComputation& LocalComputation::computation() const { return computation_; } @@ -281,8 +282,12 @@ StatusOr<Shape> LocalComputation::GetReturnValueShape() const { return std::move(*program_shape.mutable_result()); } +LocalOp::LocalOp(const XlaOp& op) : op_(op) {} + +const XlaOp& LocalOp::op() const { return op_; } + LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) - : builder_(GetOrCreateLocalClient(), computation_name) {} + : builder_(computation_name) {} void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); @@ -291,19 +296,21 @@ void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } StatusOr<LocalComputation*> LocalComputationBuilder::Build() { - TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build()); + TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); return new LocalComputation(std::move(computation)); } -ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { return builder_.Parameter(parameter_number, shape, name); } std::unique_ptr<Shape> LocalComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - return builder_.GetShape(operand).ConsumeValueOrDie(); + const LocalOp& operand) { + auto result = MakeUnique<Shape>(); + *result = builder_.GetShape(operand.op()).ValueOrDie(); + return result; } StatusOr<Shape> LocalComputationBuilder::GetReturnValueShape() { @@ -311,222 +318,236 @@ StatusOr<Shape> LocalComputationBuilder::GetReturnValueShape() { return program_shape.result(); } -ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { return builder_.Infeed(shape); } -void LocalComputationBuilder::Outfeed(const ComputationDataHandle& operand, +void LocalComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config) { - builder_.Outfeed(operand, shape, outfeed_config); + builder_.Outfeed(operand.op(), shape, outfeed_config); } -ComputationDataHandle LocalComputationBuilder::ConstantLiteral( - const Literal& literal) { +LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { return builder_.ConstantLiteral(literal); } -ComputationDataHandle LocalComputationBuilder::Broadcast( - const ComputationDataHandle& operand, +LocalOp LocalComputationBuilder::Broadcast( + const LocalOp& 
operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) { - return builder_.Broadcast(operand, broadcast_sizes); + return builder_.Broadcast(operand.op(), broadcast_sizes); } -ComputationDataHandle LocalComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - return builder_.Pad(operand, padding_value, padding_config); +LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const PaddingConfig& padding_config) { + return builder_.Pad(operand.op(), padding_value.op(), padding_config); } -ComputationDataHandle LocalComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> dimensions, +LocalOp LocalComputationBuilder::Reshape( + const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions, tensorflow::gtl::ArraySlice<int64> new_sizes) { - return builder_.Reshape(operand, dimensions, new_sizes); + return builder_.Reshape(operand.op(), dimensions, new_sizes); } -ComputationDataHandle LocalComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> dimensions) { - return builder_.Collapse(operand, dimensions); +LocalOp LocalComputationBuilder::Collapse( + const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) { + return builder_.Collapse(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - return builder_.CrossReplicaSum(operand); +LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { + return builder_.CrossReplicaSum(operand.op()); } -ComputationDataHandle LocalComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> start_indices, +LocalOp LocalComputationBuilder::Slice( + const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> start_indices, tensorflow::gtl::ArraySlice<int64> limit_indices, tensorflow::gtl::ArraySlice<int64> strides) { - return builder_.Slice(operand, start_indices, limit_indices, strides); + return builder_.Slice(operand.op(), start_indices, limit_indices, strides); } -ComputationDataHandle LocalComputationBuilder::SliceInDim( - const ComputationDataHandle& operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno) { - return builder_.SliceInDim(operand, start_index, limit_index, stride, dimno); +LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, + int64 limit_index, int64 stride, + int64 dimno) { + return builder_.SliceInDim(operand.op(), start_index, limit_index, stride, + dimno); } -ComputationDataHandle LocalComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, +LocalOp LocalComputationBuilder::DynamicSlice( + const LocalOp& operand, const LocalOp& start_indices, tensorflow::gtl::ArraySlice<int64> slice_sizes) { - return builder_.DynamicSlice(operand, start_indices, slice_sizes); + return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { - return builder_.DynamicUpdateSlice(operand, update, start_indices); +LocalOp LocalComputationBuilder::DynamicUpdateSlice( + const LocalOp& operand, const LocalOp& update, + const LocalOp& 
start_indices) { + return builder_.DynamicUpdateSlice(operand.op(), update.op(), + start_indices.op()); } -ComputationDataHandle LocalComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, - int64 dimension) { - return builder_.ConcatInDim(operands, dimension); +LocalOp LocalComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice<LocalOp> operands, int64 dimension) { + std::vector<XlaOp> xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.ConcatInDim(xla_ops, dimension); } -ComputationDataHandle -LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, +LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter) { + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter) { return builder_.SelectAndScatterWithGeneralPadding( - operand, select.computation(), window_dimensions, window_strides, padding, - source, init_value, scatter.computation()); + operand.op(), select.computation(), window_dimensions, window_strides, + padding, source.op(), init_value.op(), scatter.computation()); } -ComputationDataHandle LocalComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) { - return builder_.Tuple(elements); +LocalOp LocalComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice<LocalOp> elements) { + std::vector<XlaOp> xla_ops; + xla_ops.reserve(elements.size()); + for (const auto& op : elements) { + xla_ops.push_back(op.op()); + } + + return builder_.Tuple(xla_ops); } -ComputationDataHandle LocalComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - return builder_.GetTupleElement(tuple_data, index); +LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { + return builder_.GetTupleElement(tuple_data.op(), index); } -ComputationDataHandle LocalComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return builder_.Dot(lhs, rhs); +LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { + return builder_.Dot(lhs.op(), rhs.op()); } -ComputationDataHandle LocalComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::DotGeneral( + const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return builder_.DotGeneral(lhs, rhs, dimension_numbers); + return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return builder_.ConvGeneralDilated(lhs, 
rhs, window_strides, padding, - lhs_dilation, rhs_dilation, + return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, + padding, lhs_dilation, rhs_dilation, dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - return builder_.ConvertElementType(operand, new_element_type); +LocalOp LocalComputationBuilder::ConvertElementType( + const LocalOp& operand, PrimitiveType new_element_type) { + return builder_.ConvertElementType(operand.op(), new_element_type); } -ComputationDataHandle LocalComputationBuilder::Call( +LocalOp LocalComputationBuilder::Call( const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) { - return builder_.Call(local_computation.computation(), operands); + tensorflow::gtl::ArraySlice<LocalOp> operands) { + std::vector<XlaOp> xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.Call(local_computation.computation(), xla_ops); } -ComputationDataHandle LocalComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> permutation) { - return builder_.Transpose(operand, permutation); +LocalOp LocalComputationBuilder::Transpose( + const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> permutation) { + return builder_.Transpose(operand.op(), permutation); } -ComputationDataHandle LocalComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> dimensions) { - return builder_.Rev(operand, dimensions); +LocalOp LocalComputationBuilder::Rev( + const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) { + return builder_.Rev(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::Map( - tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, +LocalOp LocalComputationBuilder::Map( + tensorflow::gtl::ArraySlice<LocalOp> operands, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice<int64> dimensions, - tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) { - return builder_.Map(operands, local_computation.computation(), dimensions, - static_operands); + tensorflow::gtl::ArraySlice<LocalOp> static_operands) { + std::vector<XlaOp> xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + + std::vector<XlaOp> static_xla_ops; + static_xla_ops.reserve(static_operands.size()); + for (const auto& op : static_operands) { + static_xla_ops.push_back(op.op()); + } + + return builder_.Map(xla_ops, local_computation.computation(), dimensions, + static_xla_ops); } -ComputationDataHandle LocalComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::Reduce( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) { - return builder_.Reduce(operand, init_value, local_computation.computation(), - dimensions_to_reduce); + return builder_.Reduce(operand.op(), init_value.op(), + local_computation.computation(), dimensions_to_reduce); } -ComputationDataHandle LocalComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp 
LocalComputationBuilder::ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) { return builder_.ReduceWindowWithGeneralPadding( - operand, init_value, local_computation.computation(), window_dimensions, - window_strides, padding); + operand.op(), init_value.op(), local_computation.computation(), + window_dimensions, window_strides, padding); } -ComputationDataHandle LocalComputationBuilder::RngNormal( - const ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return builder_.RngNormal(mu, sigma, shape); +LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, + const LocalOp& sigma, + const Shape& shape) { + return builder_.RngNormal(mu.op(), sigma.op(), shape); } -ComputationDataHandle LocalComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return builder_.RngUniform(a, b, shape); +LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { + return builder_.RngUniform(a.op(), b.op(), shape); } -ComputationDataHandle LocalComputationBuilder::While( - const LocalComputation& condition, const LocalComputation& body, - const ComputationDataHandle& init) { - return builder_.While(condition.computation(), body.computation(), init); +LocalOp LocalComputationBuilder::While(const LocalComputation& condition, + const LocalComputation& body, + const LocalOp& init) { + return builder_.While(condition.computation(), body.computation(), init.op()); } -ComputationDataHandle LocalComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, +LocalOp LocalComputationBuilder::Conditional( + const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, const LocalOp& false_operand, const LocalComputation& false_computation) { - return builder_.Conditional(predicate, true_operand, - true_computation.computation(), false_operand, - false_computation.computation()); + return builder_.Conditional( + predicate.op(), true_operand.op(), true_computation.computation(), + false_operand.op(), false_computation.computation()); } -StatusOr<bool> LocalComputationBuilder::IsConstant( - const ComputationDataHandle& operand, int64 num_parameters) { - return builder_.IsConstant(operand, num_parameters); +StatusOr<bool> LocalComputationBuilder::IsConstant(const LocalOp& operand) { + return builder_.IsConstant(operand.op()); } -StatusOr<std::unique_ptr<Literal>> LocalComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice<Literal> parameters) { - return builder_.ComputeConstant(operand, output_layout, parameters); +StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph( + const LocalOp& operand) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, + builder_.BuildConstantSubGraph(operand.op())); + return new LocalComputation(std::move(computation)); } #define _FORWARD(method_name, return_sig, args_sig, args) \ @@ -534,23 +555,19 @@ StatusOr<std::unique_ptr<Literal>> LocalComputationBuilder::ComputeConstant( return builder_.method_name args; \ } 
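// A sketch of what the LocalOp-based forwarding macros below produce,
// assuming Neg is among the forwarded unary ops (the name is illustrative):
// _FORWARD_UNOP(Neg) roughly expands to
//
//   LocalOp LocalComputationBuilder::Neg(const LocalOp& operand) {
//     return builder_.Neg(operand.op());
//   }
//
// i.e. each generated wrapper unwraps its LocalOp arguments via op() and
// forwards to the underlying XlaBuilder, mirroring the hand-written methods
// above.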
-#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand), (operand)) - -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \ - (lhs, rhs, broadcast_dimensions)) - -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs), \ - (lhs, rhs, ehs)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \ + (lhs.op(), rhs.op(), broadcast_dimensions)) + +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \ + (lhs.op(), rhs.op(), ehs.op())) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 5ec097846a..a06b85b4ea 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -17,9 +17,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -97,25 +98,37 @@ class CompiledLocalComputation { std::unique_ptr<LocalExecutable> executable_; }; -// Wraps a Computation produced by a LocalComputationBuilder. The +// Wraps an XlaComputation produced by a LocalComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. class LocalComputation { public: - LocalComputation(Computation computation); + LocalComputation(XlaComputation computation); StatusOr<CompiledLocalComputation*> Compile( const std::vector<Shape>& argument_shapes, const ExecutableBuildOptions* build_options); - const Computation& computation() const; + const XlaComputation& computation() const; // Returns the return-value shape for this computation. StatusOr<Shape> GetReturnValueShape() const; private: - Computation computation_; + XlaComputation computation_; +}; + +// Wraps an XlaOp produced by a LocalComputationBuilder. This class is intended +// to be made available to Python via SWIG. +class LocalOp { + public: + LocalOp(const XlaOp& op); + + const XlaOp& op() const; + + private: + XlaOp op_; }; // Wraps the ComputationBuilder API in order to: @@ -135,166 +148,137 @@ class LocalComputationBuilder { // Returns an owned LocalComputation to the caller on success.
StatusOr<LocalComputation*> Build(); - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); + LocalOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); - std::unique_ptr<Shape> GetShape(const ComputationDataHandle& operand); + std::unique_ptr<Shape> GetShape(const LocalOp& operand); // Returns the shape of the current return value for the computation. StatusOr<Shape> GetReturnValueShape(); - ComputationDataHandle Infeed(const Shape& shape); + LocalOp Infeed(const Shape& shape); - void Outfeed(const ComputationDataHandle& operand, const Shape& shape, + void Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config); - ComputationDataHandle ConstantLiteral(const Literal& literal); + LocalOp ConstantLiteral(const Literal& literal); - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> broadcast_sizes); + LocalOp Broadcast(const LocalOp& operand, + tensorflow::gtl::ArraySlice<int64> broadcast_sizes); - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); + LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, + const PaddingConfig& padding_config); - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> dimensions, - tensorflow::gtl::ArraySlice<int64> new_sizes); + LocalOp Reshape(const LocalOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<int64> new_sizes); - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> dimensions); + LocalOp Collapse(const LocalOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions); - ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); + LocalOp CrossReplicaSum(const LocalOp& operand); - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> start_indices, - tensorflow::gtl::ArraySlice<int64> limit_indices, - tensorflow::gtl::ArraySlice<int64> strides); + LocalOp Slice(const LocalOp& operand, + tensorflow::gtl::ArraySlice<int64> start_indices, + tensorflow::gtl::ArraySlice<int64> limit_indices, + tensorflow::gtl::ArraySlice<int64> strides); - ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno); + LocalOp SliceInDim(const LocalOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice<int64> slice_sizes); + LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, + tensorflow::gtl::ArraySlice<int64> slice_sizes); - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); + LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices); - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, - int64 dimension); + LocalOp ConcatInDim(tensorflow::gtl::ArraySlice<LocalOp> operands, + int64 dimension); - ComputationDataHandle SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& 
select, + LocalOp SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter); + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter); - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice<ComputationDataHandle> elements); + LocalOp Tuple(tensorflow::gtl::ArraySlice<LocalOp> elements); - ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); + LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); - ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); + LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs); - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers); + LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs, + const DotDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + LocalOp ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding, tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); + LocalOp ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type); - ComputationDataHandle Call( - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice<ComputationDataHandle> operands); + LocalOp Call(const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice<LocalOp> operands); - ComputationDataHandle Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> permutation); + LocalOp Transpose(const LocalOp& operand, + tensorflow::gtl::ArraySlice<int64> permutation); - ComputationDataHandle Rev(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice<int64> dimensions); + LocalOp Rev(const LocalOp& operand, + tensorflow::gtl::ArraySlice<int64> dimensions); - ComputationDataHandle Map( - tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice<int64> dimensions, - tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands); + LocalOp Map(tensorflow::gtl::ArraySlice<LocalOp> operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<LocalOp> static_operands); - ComputationDataHandle Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); + LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); - ComputationDataHandle ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& 
init_value, + LocalOp ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice<int64> window_dimensions, tensorflow::gtl::ArraySlice<int64> window_strides, tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding); - ComputationDataHandle RngNormal(const ComputationDataHandle& mu, - const ComputationDataHandle& sigma, - const Shape& shape); + LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape); - ComputationDataHandle RngUniform(const ComputationDataHandle& a, - const ComputationDataHandle& b, - const Shape& shape); + LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - ComputationDataHandle While(const LocalComputation& condition, - const LocalComputation& body, - const ComputationDataHandle& init); + LocalOp While(const LocalComputation& condition, const LocalComputation& body, + const LocalOp& init); - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, - const LocalComputation& false_computation); + LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, + const LocalOp& false_operand, + const LocalComputation& false_computation); - StatusOr<bool> IsConstant(const ComputationDataHandle& operand, - int64 num_parameters); + StatusOr<bool> IsConstant(const LocalOp& operand); - StatusOr<std::unique_ptr<Literal> > ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice<Literal> parameters); + StatusOr<LocalComputation*> BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand)) -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice<int64> broadcast_dimensions)) +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions)) -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs)) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) @@ -338,7 +322,7 @@ class LocalComputationBuilder { #undef _FORWARD_TRIOP private: - ComputationBuilder builder_; + XlaBuilder builder_; }; // Functions for freeing resources from the Python side. diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index b8cce5a5f7..04c56bbba9 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,9 +22,8 @@ limitations under the License. 
// // C++ Python // -------------------------------------+--------------------------------------- -// ComputationDataHandle <-> int // ArraySlice<int64> <- sequence of int -// ArraySlice<ComputationDataHandle> <- sequence of int +// ArraySlice<LocalOp> <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector<Literal> <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) @@ -91,12 +90,9 @@ limitations under the License. // One central reason for the Python-side indirection is that the // Python-side objects produced by the typemaps in this file are // further packaged up by xla_client before being passed on. For -// instance, xla_client wraps the long produced for a C++ -// ComputationDataHandle in a Python ComputationDataHandle proto, -// rather than exposing a raw long outside of the client. Similarly, -// the Python pair produced for a C++ Shape is further wrapped in a -// Python class (xla_client.Shape) so as not to expose the raw pair -// externally. +// instance, the Python pair produced for a C++ Shape is further +// wrapped in a Python class (xla_client.Shape) so as not to expose +// the raw pair externally. // // Other SWIG object wrappers (e.g. of LocalComputation) are further // wrapped by xla_client in order to set up a custom destructor that @@ -124,6 +120,7 @@ using namespace xla; using namespace xla::swig; namespace xla { + namespace swig { bool GetIntAttr(PyObject* o, const char* field, int64* result) { @@ -177,21 +174,6 @@ bool HandleStringAttribute(PyObject* o, tensorflow::ImportNumpy(); %} -// ComputationDataHandle - -%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { - const int64 handle = numpy::PyIntOrPyLongToLong($input); - if (handle == -1 && PyErr_Occurred()) { - SWIG_fail; - } - temp.set_handle(handle); - $1 = &temp; -} - -%typemap(out) ComputationDataHandle { - $result = numpy::LongToPyIntOrPyLong($1.handle()); -} - %typemap(out) StatusOr<xla::swig::CompiledLocalComputation*> { if ($1.ok()) { auto* value = $1.ValueOrDie(); @@ -301,33 +283,23 @@ tensorflow::ImportNumpy(); $1 = temps; } -// ComputationDataHandle +// ArraySlice<LocalOp> -%typemap(in) tensorflow::gtl::ArraySlice<ComputationDataHandle> - (std::vector<ComputationDataHandle> temps) { +%typemap(in) tensorflow::gtl::ArraySlice<xla::swig::LocalOp>( + std::vector<LocalOp> temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); SWIG_fail; } const int size = PySequence_Size($input); - temps.resize(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - SWIG_fail; - } - const int64 handle = numpy::PyIntOrPyLongToLong(py_int); - if (handle == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); + LocalOp* op; + if ((SWIG_ConvertPtr(o, (void**)&op, $descriptor(xla::swig::LocalOp*), + SWIG_POINTER_EXCEPTION)) == -1) { SWIG_fail; } - temps[i].set_handle(handle); - Py_DECREF(py_int); + temps.push_back(*op); Py_DECREF(o); } $1 = temps; @@ -934,6 +906,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; %unignore xla::swig::LocalComputation::GetReturnValueShape; +%unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; %unignore 
xla::swig::LocalComputationBuilder::Build; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index f6809b6b87..1d5b75d1be 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -335,20 +335,6 @@ def _wrap_shape(shape_info): return Shape.array_shape(dtype, dims) -def _wrap_data_handle(handle): - cdh = xla_data_pb2.ComputationDataHandle() - cdh.handle = handle - return cdh - - -def _unwrap_data_handle(handle_proto): - return handle_proto.handle - - -def _unwrap_data_handles(handle_protos): - return [_unwrap_data_handle(cdh) for cdh in handle_protos] - - def require_numpy_array_layout(value): if isinstance(value, tuple): return tuple(require_numpy_array_layout(x) for x in value) @@ -535,9 +521,9 @@ class ComputationBuilder(object): queue for subsequent use in the computation. Returns: - A ComputationDataHandle message. + A LocalOp. """ - return _wrap_data_handle(self._client.Infeed(shape)) + return self._client.Infeed(shape) def Outfeed(self, operand): """Enqueues an outfeed op onto the computation. @@ -545,9 +531,7 @@ class ComputationBuilder(object): Outfeed operations enqueue data, using the given operand, onto the XLA outfeed queue for subsequent dequeue via the client API. """ - self._client.Outfeed( - _unwrap_data_handle(operand), self.GetShape(operand), - ''.encode('utf-8')) + self._client.Outfeed(operand, self.GetShape(operand), ''.encode('utf-8')) def Constant(self, value): """Enqueues a constant op onto the computation. @@ -557,10 +541,10 @@ class ComputationBuilder(object): to one of the supported types. Returns: - A ComputationDataHandle message. + A LocalOp. """ value = require_numpy_array_layout(value) - return _wrap_data_handle(self._client.ConstantLiteral(value)) + return self._client.ConstantLiteral(value) def ConstantF32Scalar(self, value): """Convenience method to enqueue a scalar F32 constant op. Args: value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float32)) @@ -580,7 +564,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float64)) @@ -591,7 +575,7 @@ class ComputationBuilder(object): value: an integer. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int32)) @@ -602,7 +586,7 @@ class ComputationBuilder(object): value: an integer. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int64)) @@ -613,7 +597,7 @@ class ComputationBuilder(object): value: a boolean value. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.bool)) @@ -629,15 +613,14 @@ class ComputationBuilder(object): parameters, use it for *all* parameters to avoid clashes. Returns: - A ComputationDataHandle message. + A LocalOp. """ if name is None: name = '' if parameter_num is None: parameter_num = next(self._parameter_numbering) - return _wrap_data_handle( - self._client.Parameter(parameter_num, shape, name.encode('utf8'))) + return self._client.Parameter(parameter_num, shape, name.encode('utf8')) def ParameterFromNumpy(self, value, name=None, parameter_num=None): """Enqueues a Parameter op onto the computation.
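# A minimal sketch of the builder API after this change (assuming the usual
# `import numpy as np` and that Add is among the forwarded binary ops): every
# method now accepts and returns LocalOp values directly, with no
# ComputationDataHandle protos to wrap or unwrap.
#
#   builder = xla_client.ComputationBuilder('add_one')
#   x = builder.ParameterFromNumpy(np.float32(0))  # a LocalOp
#   one = builder.ConstantF32Scalar(1.0)           # also a LocalOp
#   builder.Add(x, one)       # LocalOps feed straight back into the builder
#   computation = builder.Build()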
@@ -649,7 +632,7 @@ class ComputationBuilder(object): parameter_num: as in ParameterWithShape. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.ParameterWithShape( Shape.from_pyval(value), name=name, parameter_num=parameter_num) @@ -658,14 +641,13 @@ class ComputationBuilder(object): """Enqueues a broadcast operation onto the computation. Args: - operand: the operand ComputationDataHandle to broadcast. + operand: the operand LocalOp to broadcast. sizes: an iterable of broadcast sizes. Returns: - A ComputationDataHandle representing the added broadcast op. + A LocalOp representing the added broadcast op. """ - return _wrap_data_handle( - self._client.Broadcast(_unwrap_data_handle(operand), sizes)) + return self._client.Broadcast(operand, sizes) def Concatenate(self, operands, dimension): """Enqueues a concatenate operation onto the computation. @@ -675,10 +657,9 @@ class ComputationBuilder(object): dimension: the dimension in which to perform the concatenation. Returns: - A ComputationDataHandle representing the added concatenate op. + A LocalOp representing the added concatenate op. """ - return _wrap_data_handle( - self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) + return self._client.ConcatInDim(operands, dimension) def ConvertElementType(self, operand, new_element_type): """Enqueues an element type conversion operation onto the computation. @@ -688,14 +669,12 @@ class ComputationBuilder(object): new_element_type: the target primitive type. Returns: - A ComputationDataHandle representing the added conversion op. + A LocalOp representing the added conversion op. """ - return _wrap_data_handle( - self._client.ConvertElementType( - _unwrap_data_handle(operand), new_element_type)) + return self._client.ConvertElementType(operand, new_element_type) def GetShape(self, operand): - return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + return _wrap_shape(self._client.GetShape(operand)) def GetReturnValueShape(self): return _wrap_shape(self._client.GetReturnValueShape()) @@ -707,40 +686,35 @@ class ComputationBuilder(object): """Enqueues a Pad operation onto the computation. Args: - operand: ComputationDataHandle representing the array to pad. - padding_value: ComputationDataHandle representing the scalar pad value. + operand: LocalOp representing the array to pad. + padding_value: LocalOp representing the scalar pad value. padding_config: either an xla_data_pb2.PaddingConfig or a list of integer triples (edge_padding_low, edge_padding_high, interior_padding) representing the configuration of the padding operation. Returns: - A ComputationDataHandle representing the added Pad op. + A LocalOp representing the added Pad op. """ if not isinstance(padding_config, xla_data_pb2.PaddingConfig): padding_config = GetPaddingConfigFromTriples(padding_config) - return _wrap_data_handle( - self._client.Pad(_unwrap_data_handle(operand), - _unwrap_data_handle(padding_value), - padding_config)) + return self._client.Pad(operand, padding_value, padding_config) def Reshape(self, operand, dimensions, new_sizes): """Enqueues a reshape op onto the computation. Args: - operand: ComputationDataHandle representing the array to be reshaped. + operand: LocalOp representing the array to be reshaped. dimensions: sequence of integers encoding the order in which dimensions are collapsed or None, in which case dimensions are flattened in order. new_sizes: sequence of integers encoding the new dimension sizes (shape). 
Returns: - A ComputationDataHandle representing the added Reshape op. + A LocalOp representing the added Reshape op. """ if dimensions is None: ndim = len(self.GetShape(operand).dimensions()) dimensions = tuple(range(ndim)) - return _wrap_data_handle( - self._client.Reshape( - _unwrap_data_handle(operand), dimensions, new_sizes)) + return self._client.Reshape(operand, dimensions, new_sizes) def CrossReplicaSum(self, operand): """CrossReplicaSum op. @@ -749,67 +723,56 @@ class ComputationBuilder(object): operand: the operand to sum across replica instances. Returns: - A ComputationDataHandle that has the sum of the value among all replicas. + A LocalOp that has the sum of the value among all replicas. """ - return _wrap_data_handle( - self._client.CrossReplicaSum(_unwrap_data_handle(operand))) + return self._client.CrossReplicaSum(operand) def Collapse(self, operand, dimensions): """Collapse op.""" - return _wrap_data_handle( - self._client.Collapse(_unwrap_data_handle(operand), dimensions)) + return self._client.Collapse(operand, dimensions) def Trans(self, operand): """Specialized matrix transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) + return self._client.Transpose(operand, [1, 0]) def Transpose(self, operand, permutation): """Transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), permutation)) + return self._client.Transpose(operand, permutation) def Rev(self, operand, dimensions): """Rev op.""" - return _wrap_data_handle( - self._client.Rev(_unwrap_data_handle(operand), dimensions)) + return self._client.Rev(operand, dimensions) def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin """Clamp op.""" - return _wrap_data_handle( - self._client.Clamp(_unwrap_data_handle(min), - _unwrap_data_handle(operand), - _unwrap_data_handle(max))) + return self._client.Clamp(min, operand, max) def SelectAndScatter(self, operand, select, window_dimensions, window_strides, padding, source, init_value, scatter): """Select and scatter op, used by the gradient of ReduceWindow. Args: - operand: ComputationDataHandle for array of dimension N and type T over + operand: LocalOp for array of dimension N and type T over which the windows slide. select: Computation of type (T, T) -> Pred to apply to the elements of each window to indicate which element is selected. window_dimensions: sequence of N integers for dimensions of the window. window_strides: sequence of N integers for the strides of the window. padding: PaddingType representing either 'SAME' or 'VALID' padding. - source: ComputationDataHandle for array of type T with values to scatter. - init_value: ComputationDataHandle of scalar type T for initial out value. + source: LocalOp for array of type T with values to scatter. + init_value: LocalOp of scalar type T for initial out value. scatter: Computation of type (T, T) -> T to apply to each scatter source element with its destination element. Returns: - A ComputationDataHandle representing the added SelectAndScatter op. + A LocalOp representing the added SelectAndScatter op.
""" pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.SelectAndScatterWithGeneralPadding( - _unwrap_data_handle(operand), select.c_local_computation, - window_dimensions, window_strides, pads, - _unwrap_data_handle(source), _unwrap_data_handle(init_value), - scatter.c_local_computation)) + return self._client.SelectAndScatterWithGeneralPadding( + operand, select.c_local_computation, window_dimensions, window_strides, + pads, source, init_value, scatter.c_local_computation) def Select(self, pred, on_true, on_false): """Element-wise selection op. @@ -817,17 +780,13 @@ class ComputationBuilder(object): Constructs an output array from elements of two input arrays, based on the values of a predicate array. """ - return _wrap_data_handle( - self._client.Select( - _unwrap_data_handle(pred), - _unwrap_data_handle(on_true), - _unwrap_data_handle(on_false))) + return self._client.Select(pred, on_true, on_false) def Slice(self, operand, start_indices, limit_indices, strides=None): """Enqueues a slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_indices: iterable of N integers containing the starting indices of the slice for each dimension. limit_indices: iterable of N integers containing the ending indices @@ -836,207 +795,177 @@ class ComputationBuilder(object): each dimension. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ if strides is None: start_indices = list(start_indices) strides = [1] * len(start_indices) - return _wrap_data_handle( - self._client.Slice( - _unwrap_data_handle(operand), start_indices, limit_indices, - strides)) + return self._client.Slice(operand, start_indices, limit_indices, strides) def SliceInDim(self, operand, start_index, limit_index, stride, dimno): """Enqueues a slice-in-dimension operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_index: an integer containing the start index of the slice. limit_index: an integer containing the end index of the slice. stride: an integer containing the stride size for the slice. dimno: an integer indicating the dimension along which to slice. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ - return _wrap_data_handle( - self._client.SliceInDim( - _unwrap_data_handle(operand), start_index, limit_index, stride, - dimno)) + return self._client.SliceInDim(operand, start_index, limit_index, stride, + dimno) def DynamicSlice(self, operand, start_indices, slice_sizes): """Enqueues a slice op with dynamic start indices onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. - start_indices: ComputationDataHandle for the 1D array of N integers + operand: LocalOp for the N dimensional array to be sliced. + start_indices: LocalOp for the 1D array of N integers containing the starting indices of the slice. slice_sizes: iterable of N integers containing the slice sizes in each dimension. Returns: - A ComputationDataHandle representing the added DynamicSlice op. + A LocalOp representing the added DynamicSlice op. 
""" - return _wrap_data_handle( - self._client.DynamicSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(start_indices), - slice_sizes)) + return self._client.DynamicSlice(operand, start_indices, slice_sizes) def DynamicUpdateSlice(self, operand, update, start_indices): """Enqueues a dynamic update slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be updated. + operand: LocalOp for the N dimensional array to be updated. update: N dimensional array comprising the slice update. start_indices: Rank-1 array of N integers comprising the starting indices of the slice along each dimension. Returns: - A ComputationDataHandle representing the added DynamicUpdateSlice op. + A LocalOp representing the added DynamicUpdateSlice op. """ - return _wrap_data_handle( - self._client.DynamicUpdateSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(update), - _unwrap_data_handle(start_indices))) + return self._client.DynamicUpdateSlice(operand, update, start_indices) def Tuple(self, *ops): """Enqueues a tuple operation onto the computation. Args: - ops: a sequence of tuple operands (each a ComputationDataHandle). + ops: a sequence of tuple operands (each a LocalOp). Returns: - A ComputationDataHandle representing the added Tuple op. + A LocalOp representing the added Tuple op. """ - return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) + return self._client.Tuple(ops) def GetTupleElement(self, tup, index): """Enqueues a 'get tuple element' operation onto the computation. Args: - tup: the tuple operand (a ComputationDataHandle). + tup: the tuple operand (a LocalOp). index: numeric index to select from the tuple. Returns: - A ComputationDataHandle representing the added GetTupleElement op. + A LocalOp representing the added GetTupleElement op. """ - return _wrap_data_handle( - self._client.GetTupleElement(_unwrap_data_handle(tup), index)) + return self._client.GetTupleElement(tup, index) def Call(self, computation_to_apply, operands): """Enqueues a call operation onto the computation. Args: computation_to_apply: a Computation object. - operands: an iterable of ComputationDataHandle. The number and types of + operands: an iterable of LocalOp. The number and types of operands must match the arity of computation_to_apply. Returns: - A ComputationDataHandle representing the added call op. + A LocalOp representing the added call op. """ - return _wrap_data_handle( - self._client.Call(computation_to_apply.c_local_computation, - _unwrap_data_handles(operands))) + return self._client.Call(computation_to_apply.c_local_computation, operands) def Map(self, operands, computation_to_apply, dimensions, static_operands=()): """Enqueues a map operation onto the computation. Args: - operands: an iterable of ComputationDataHandle. + operands: an iterable of LocalOp. computation_to_apply: a Computation object. dimensions: dimensions over which to apply map the function. static_operands: auxiliary arguments passed to the applied computation. Returns: - A ComputationDataHandle representing the added Map op. + A LocalOp representing the added Map op. 
""" - return _wrap_data_handle( - self._client.Map( - _unwrap_data_handles(operands), - computation_to_apply.c_local_computation, - dimensions, - _unwrap_data_handles(static_operands))) + return self._client.Map(operands, computation_to_apply.c_local_computation, + dimensions, static_operands) def Reduce(self, operand, init_value, computation_to_apply, dimensions): """Enqueues a reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a Computation object - binary reduction function. dimensions: sequence of dimensions (integers) to reduce on. Returns: - A ComputationDataHandle representing the added Reduce op. + A LocalOp representing the added Reduce op. """ - return _wrap_data_handle( - self._client.Reduce( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - dimensions)) + return self._client.Reduce(operand, init_value, + computation_to_apply.c_local_computation, + dimensions) def ReduceWindow(self, operand, init_value, computation_to_apply, window_dimensions, window_strides, padding): """Enqueues a windowed reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a binary reduction function (Computation). window_dimensions: dimensions of window (sequence of integers). window_strides: strides for window (sequence of integers). padding: PaddingType representing either 'SAME' or 'VALID' padding. Returns: - A ComputationDataHandle representing the added ReduceWindow op. + A LocalOp representing the added ReduceWindow op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.ReduceWindowWithGeneralPadding( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads)) + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.c_local_computation, + window_dimensions, window_strides, pads) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. Args: - mu: A ComputationDataHandle to an F32 scalar specifying the mean. - sigma: A ComputationDataHandle to an F32 scalar specifying the standard + mu: A LocalOp to an F32 scalar specifying the mean. + sigma: A LocalOp to an F32 scalar specifying the standard deviation. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of F32 values. + Returns: a LocalOp to the generated array of F32 values. """ shape = Shape.array_shape(self.GetShape(mu).element_type(), dims) - return _wrap_data_handle( - self._client.RngNormal( - _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) + return self._client.RngNormal(mu, sigma, shape) def RngUniform(self, a, b, dims): """Enqueues an RngUniform operation onto the computation. 
Args: - a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + a: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of b) specifying the low end of the interval [a, b) over which values are generated. - b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + b: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of a) specifying the high end of the interval [a, b) over which values are generated. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of values with the + Returns: a LocalOp to the generated array of values with the same numeric type (F32, S32, or U32) as the arguments a and b. """ shape = Shape.array_shape(self.GetShape(a).element_type(), dims) - return _wrap_data_handle( - self._client.RngUniform( - _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) + return self._client.RngUniform(a, b, shape) def While(self, cond, body, init): """Enqueues a While operation onto the computation. @@ -1044,112 +973,105 @@ class ComputationBuilder(object): Args: cond: a Computation for the loop condition, which has type T -> PRED body: a Computation for the loop body, which has type T -> T - init: a ComputationDataHandle for the initial parameter, which has type T + init: a LocalOp for the initial parameter, which has type T - Returns: a ComputationDataHandle representing the While operation. + Returns: a LocalOp representing the While operation. """ - return _wrap_data_handle( - self._client.While(cond.c_local_computation, - body.c_local_computation, - _unwrap_data_handle(init))) + return self._client.While(cond.c_local_computation, + body.c_local_computation, init) def Conditional(self, pred, true_operand, true_computation, false_operand, false_computation): """Enqueues a Conditional operation onto the computation. Args: - predicate: a ComputationDataHandle to test, which has scalar type PRED - true_operand: a ComputationDataHandle of type T_0 + predicate: a LocalOp to test, which has scalar type PRED + true_operand: a LocalOp of type T_0 true_computation: a Computation to apply to true_operand, type T_0 -> S false_operand: a LocalOp of type T_1 false_computation: a Computation to apply to false_operand, type T_1 -> S - Returns: a ComputationDataHandle representing the Conditional operation. + Returns: a LocalOp representing the Conditional operation. """ - return _wrap_data_handle( - self._client.Conditional( - _unwrap_data_handle(pred), _unwrap_data_handle(true_operand), - true_computation.c_local_computation, - _unwrap_data_handle(false_operand), - false_computation.c_local_computation)) + return self._client.Conditional( + pred, true_operand, true_computation.c_local_computation, false_operand, + false_computation.c_local_computation) - def IsConstant(self, operand, num_parameters=0): - """Enqueues an IsConstant operation onto the computation. + def IsConstant(self, operand): + """Checks whether the given operand is a compile-time constant. Args: operand: a LocalOp to test. - num_parameters: optional int, number of computation parameters to treat as - constant (default 0). Returns: bool indicating whether `operand` is a compile-time constant, - meaning its value does not depend on parameters with index greater than or - equal to `num_parameters`. + meaning its value does not depend on any parameters, or on stateful + operators such as `RngNormal` or `Infeed`.
+ """ + return self._client.IsConstant(operand) + + def BuildConstantSubGraph(self, operand): + """Builds a constant sub graph. + + Args: + operand: a LocalOp to test. + Returns: a LocalComputation that is rooted on the given `operand` which is a + compile-time constant. """ - return self._client.IsConstant(_unwrap_data_handle(operand), num_parameters) + return self._client.BuildConstantSubGraph(operand) def Dot(self, lhs, rhs): """Enqueues a dot operation onto the computation. Args: - lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array. - rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array. + lhs: LocalOp for the rank 1 or rank 2 left-hand-side array. + rhs: LocalOp for the rank 1 or rank 2 right-hand-side array. - Returns: a ComputationDataHandle representing the Dot operation. + Returns: a LocalOp representing the Dot operation. """ - return _wrap_data_handle( - self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + return self._client.Dot(lhs, rhs) def DotGeneral(self, lhs, rhs, dimension_numbers): """Enqueues a general dot operation onto the computation. Args: - lhs: ComputationDataHandle for the left-hand-side array. - rhs: ComputationDataHandle for the right-hand-side array. + lhs: LocalOp for the left-hand-side array. + rhs: LocalOp for the right-hand-side array. dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of integers representing the dimensions to treat as contracting dimensions and batch dimensions on each input operand. - Returns: a ComputationDataHandle representing the DotGeneral operation. + Returns: a LocalOp representing the DotGeneral operation. """ if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) - return _wrap_data_handle( - self._client.DotGeneral( - _unwrap_data_handle(lhs), _unwrap_data_handle(rhs), - dimension_numbers)) + return self._client.DotGeneral(lhs, rhs, dimension_numbers) def Conv(self, lhs, rhs, window_strides, padding): """Enqueues a Conv operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. - Returns: a ComputationDataHandle representing the Conv operation. + Returns: a LocalOp representing the Conv operation. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(lhs).dimensions()[2:], self.GetShape(rhs).dimensions()[2:], window_strides) dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - pads, - (), - (), - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), + (), dimension_numbers) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. 
+      rhs: LocalOp for the rank N+2 array of kernel weights.
      window_strides: length-N array-like of kernel strides.
      padding: length-N array-like of pairs of integers of (low, high) padding.
      lhs_dilation: length-N array-like of dilation factors.
@@ -1159,14 +1081,9 @@ class ComputationBuilder(object):
      A ComputationDataHandle representing the added ConvWithGeneralPadding op.
    """
    dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
-    return _wrap_data_handle(
-        self._client.ConvGeneralDilated(_unwrap_data_handle(lhs),
-                                        _unwrap_data_handle(rhs),
-                                        window_strides,
-                                        padding,
-                                        lhs_dilation,
-                                        rhs_dilation,
-                                        dimension_numbers))
+    return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
+                                           lhs_dilation, rhs_dilation,
+                                           dimension_numbers)

  def _GetConvDimensionNumbers(self, num_spatial_dims):
    """Create ConvolutionDimensionNumbers proto for convolutions."""
@@ -1196,15 +1113,14 @@ def _forward_methods_to_local_builder():
    """Generate a forwarding method that wraps/unwraps data handles."""

    def forward(self, *args, **kwargs):
-      unwrapped_args = [_unwrap_data_handle(arg) for arg in args]
+      arg_list = list(args)

-      if is_binop and len(unwrapped_args) < 3:
-        unwrapped_args.append(kwargs.get('broadcast_dimensions', ()))
+      if is_binop and len(arg_list) < 3:
+        arg_list.append(kwargs.get('broadcast_dimensions', ()))

-      return _wrap_data_handle(
-          target_method(
-              self._client,  # pylint: disable=protected-access
-              *unwrapped_args))
+      return target_method(
+          self._client,  # pylint: disable=protected-access
+          *arg_list)

    return forward

diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 9c362d8cad..aa3a6261e0 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -26,6 +26,7 @@ xla_proto_library(
 xla_proto_library(
     name = "hlo_proto",
     srcs = ["hlo.proto"],
+    visibility = ["//visibility:public"],
     deps = ["//tensorflow/compiler/xla:xla_data_proto"],
 )

@@ -200,7 +201,22 @@ tf_cc_test(
 cc_library(
     name = "hlo_evaluator",
-    srcs = ["hlo_evaluator.cc"],
+    srcs = [
+        "hlo_evaluator.cc",
+        "hlo_evaluator_typed_visitor.h",
+        "hlo_evaluator_typed_visitor_bfloat16.cc",
+        "hlo_evaluator_typed_visitor_bool.cc",
+        "hlo_evaluator_typed_visitor_complex64.cc",
+        "hlo_evaluator_typed_visitor_double.cc",
+        "hlo_evaluator_typed_visitor_float.cc",
+        "hlo_evaluator_typed_visitor_half.cc",
+        "hlo_evaluator_typed_visitor_int32.cc",
+        "hlo_evaluator_typed_visitor_int64.cc",
+        "hlo_evaluator_typed_visitor_int8.cc",
+        "hlo_evaluator_typed_visitor_uint32.cc",
+        "hlo_evaluator_typed_visitor_uint64.cc",
+        "hlo_evaluator_typed_visitor_uint8.cc",
+    ],
     hdrs = ["hlo_evaluator.h"],
     deps = [
         ":hlo",
@@ -370,6 +386,7 @@ tf_cc_test(
         ":hlo_matchers",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
     ],
 )

@@ -2467,6 +2484,7 @@ tf_cc_test(
     srcs = ["transpose_folding_test.cc"],
     deps = [
         ":hlo",
+        ":hlo_matchers",
         ":shape_inference",
         ":transpose_folding",
         "//tensorflow/compiler/xla:literal_util",
@@ -2478,6 +2496,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
         "//tensorflow/core:lib",
     ],
 )

diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 8e785de68c..4ec79a0244 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -291,6 +291,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); + StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -912,6 +914,134 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( return add_result; } +StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather( + HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_contracting_dimensions_size() != 1 || + dnums.rhs_contracting_dimensions_size() != 1 || + dnums.lhs_batch_dimensions_size() != 0 || + dnums.rhs_batch_dimensions_size() != 0 || + dot->shape().dimensions_size() != 2) { // dot output 2D + VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations."; + return nullptr; + } + + // Optimize either dot(DS(ctA), ctB)) or dot(ctB, DS(ctA)). + // Currently a Gather is a DynamicSlice. + auto is_dynamic_slice_constant_combination = + [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) { + // First operand is a DynamicSlice(Constant). + if (a->opcode() != HloOpcode::kDynamicSlice) { + return false; + } + auto* dynamic_slice_op = a->operand(0); + if (dynamic_slice_op->opcode() != HloOpcode::kConstant) { + return false; + } + // Second operand is a Constant. + if (b->opcode() != HloOpcode::kConstant) { + return false; + } + // The DynamicSlice output is a vector. + const Shape& dynamic_slice_shape = a->shape(); + if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) { + return false; + } + // Constant size is the same before and after slice in the contracting + // dimension, otherwise we either must precompute for all possible slice + // indices or dot is invalid. + const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape(); + if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) != + dynamic_slice_shape.dimensions(a_contracting_dimension)) { + return false; + } + return true; + }; + + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + + if (!is_dynamic_slice_constant_combination( + lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) && + !is_dynamic_slice_constant_combination( + rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) { + VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or " + "dot(ctB, DS(ctA)), where the two constants have equal " + "contracting dimensions."; + return nullptr; + } + + // LHS is DynamicSlice: + // input: dot(DS(ctA), ctB)) + // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}. + // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}. + // output: DS(dot(ctA, ctB)) + // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}. + + // RHS is DynamicSlice: + // input: dot(ctA, DS(ctB)) + // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}). + // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}. 
+ // output: DS(dot(ctA, ctB)) + // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}. + + bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice; + + // ctA: + HloInstruction* left_operand = + lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs; + // ctB: + HloInstruction* right_operand = + lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0); + // Build ctA x ctB. + const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension); + const int n = + right_operand->shape().dimensions(1 - rhs_contracting_dimension); + auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); + auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( + memoized_shape, left_operand, right_operand, dnums)); + // Get pair {start, 0} or {0, start}. + HloInstruction* original_start_indices = + lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); + // Position of start: + int index_of_non_zero_start = lhs_is_dynamic_slice + ? 1 - lhs_contracting_dimension + : 1 - rhs_contracting_dimension; + // Position of zero: + int index_of_zero_start = 1 - index_of_non_zero_start; + + // Slice out start and 0 components and reorder if necessary. + auto indices_type = original_start_indices->shape().element_type(); + Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); + Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); + HloInstruction* non_zero_start = + computation_->AddInstruction(HloInstruction::CreateSlice( + s_shape, original_start_indices, {index_of_non_zero_start}, + {index_of_non_zero_start + 1}, {1})); + HloInstruction* zero_start = + computation_->AddInstruction(HloInstruction::CreateSlice( + s_shape, original_start_indices, {index_of_zero_start}, + {index_of_zero_start + 1}, {1})); + HloInstruction* new_start_indices = + lhs_is_dynamic_slice + ? computation_->AddInstruction(HloInstruction::CreateConcatenate( + d_shape, {non_zero_start, zero_start}, 0)) + : computation_->AddInstruction(HloInstruction::CreateConcatenate( + d_shape, {zero_start, non_zero_start}, 0)); + + // Build DynamicSlice(ctA x ctB). + const int new_slice_m = lhs_is_dynamic_slice ? 1 : m; + const int new_slice_n = lhs_is_dynamic_slice ? n : 1; + auto* memoized_lookup = + computation_->AddInstruction(HloInstruction::CreateDynamicSlice( + dot->shape(), memoized_inst, new_start_indices, + {new_slice_m, new_slice_n})); + + return memoized_lookup; +} + Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); @@ -941,6 +1071,17 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, dot_of_concat_optimized); } + // Simplify dot(ConstA, Gather(Index, ConstB)) to: + // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately + // batched version of dot. 
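[Editor's note] The comment above summarizes the whole rewrite: because ctA and ctB are constants, the full product can be computed once, and each dynamic start index afterwards only pays for a slice of it. Correctness rests on slicing-before-dot equaling slicing-after-dot. Below is a minimal standalone check of the LHS case, dot(DS(ctA), ctB) = DS(dot(ctA, ctB)), in plain C++; this is not XLA code, and MatMul and all other names are hypothetical.

#include <cassert>
#include <vector>

// Row-major matrix product: (rows_a x k) times (k x cols_b).
static std::vector<double> MatMul(const std::vector<double>& a,
                                  const std::vector<double>& b, int rows_a,
                                  int k, int cols_b) {
  std::vector<double> out(rows_a * cols_b, 0.0);
  for (int i = 0; i < rows_a; ++i)
    for (int j = 0; j < cols_b; ++j)
      for (int p = 0; p < k; ++p)
        out[i * cols_b + j] += a[i * k + p] * b[p * cols_b + j];
  return out;
}

int main() {
  const int M = 4, K = 3, N = 2, start = 2;  // `start` plays the DS index.
  std::vector<double> ctA(M * K), ctB(K * N);
  for (int i = 0; i < M * K; ++i) ctA[i] = i + 1;
  for (int i = 0; i < K * N; ++i) ctB[i] = 2 * i - 5;

  // dot(DS(ctA), ctB): slice row `start` of ctA first, then multiply.
  std::vector<double> row(ctA.begin() + start * K,
                          ctA.begin() + (start + 1) * K);
  std::vector<double> slice_then_dot = MatMul(row, ctB, 1, K, N);

  // DS(dot(ctA, ctB)): multiply first, then slice row `start` of the product.
  std::vector<double> full = MatMul(ctA, ctB, M, K, N);
  for (int j = 0; j < N; ++j) assert(slice_then_dot[j] == full[start * N + j]);
  return 0;
}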
+ TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized, + OptimizeDotOfGather(dot)); + if (dot_of_gather_optimized) { + VLOG(10) << "Replaced dot(constA, gather(i, constB)) with " + "gather(i, dot*(constA, constB))"; + return ReplaceInstruction(dot, dot_of_gather_optimized); + } + if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { TF_ASSIGN_OR_RETURN(bool did_strength_reduction, HandleDotStrengthReduction(dot)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index d0c99bf818..4e082877c7 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2963,5 +2963,208 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, DotOfConcatSimplificationTest, ::testing::ValuesIn(kDotOfConcatTestSpecs)); + +struct DotOfGatherTestSpec { + int64 m; + int64 k; + int64 n; + int s; // start index for dynamic slice on the non-contracting dimension + int64 lcd; // left contracting dimension + int64 rcd; // right contracting dimension + bool neg; // is negative testcase +}; + +class DotOfGatherSimplificationTest + : public HloVerifiedTestBase, + public ::testing::WithParamInterface<DotOfGatherTestSpec> {}; + +// input: dot(DS(ctA), ctB)) +// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}. +// => input dimensions: dot({1 x K}, {K x N}) => {1 x N}. +// output: DS(dot(ctA, ctB)) +// => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}. +TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { + HloComputation::Builder builder(TestName()); + + DotOfGatherTestSpec spec = GetParam(); + + ASSERT_LE(spec.s, spec.m); + + // For negative tests, increase k of the dynamic slice argument to prevent the + // optimization (constants ctA, ctB must have equal contracting dimensions). + int64 k_increase = spec.neg ? 5 : 0; + int64 lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m; + int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase); + Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, + /*cols=*/lhs_cols))); + + int32 start_row = (spec.lcd == 0) ? 0 : spec.s; + int32 start_col = (spec.lcd == 0) ? spec.s : 0; + const auto start_indices = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1<int32>({start_row, start_col}))); + int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1; + int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k; + Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ds_shape, lhs, start_indices, {slice_row_size, slice_col_size})); + + int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n; + int64 rhs_cols = (spec.rcd == 0) ? 
spec.n : spec.k; + Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, + /*cols=*/rhs_cols))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(spec.lcd); + dot_dnums.add_rhs_contracting_dimensions(spec.rcd); + + int64 dot_row_size = 1; + int64 dot_col_size = spec.n; + Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + if (spec.neg) { + EXPECT_NE(computation->root_instruction()->opcode(), + HloOpcode::kDynamicSlice); + } else { + EXPECT_THAT(computation->root_instruction(), + op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), + op::Concatenate())); + } +} + +// input: dot(ctA, DS(ctB)) +// where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}). +// => input dimensions: dot({M x K}, {K x 1}) => {M x 1}. +// output: DS(dot(ctA, ctB)) +// => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}. +TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { + HloComputation::Builder builder(TestName()); + + DotOfGatherTestSpec spec = GetParam(); + + ASSERT_LE(spec.s, spec.n); + + int64 lhs_rows = (spec.lcd == 0) ? spec.k : spec.m; + int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k; + Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, + /*cols=*/lhs_cols))); + + // For negative tests increase k of the dynamic slice argument to prevent the + // optimization + int64 k_increase = spec.neg ? 5 : 0; + int64 rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n; + int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, + /*cols=*/rhs_cols))); + + int32 start_row = (spec.rcd == 0) ? 0 : spec.s; + int32 start_col = (spec.rcd == 0) ? spec.s : 0; + const auto start_indices = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1<int32>({start_row, start_col}))); + int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1; + int64 slice_col_size = (spec.rcd == 0) ? 
1 : spec.k; + Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ds_shape, rhs, start_indices, {slice_row_size, slice_col_size})); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(spec.lcd); + dot_dnums.add_rhs_contracting_dimensions(spec.rcd); + + int64 dot_row_size = spec.m; + int64 dot_col_size = 1; + Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + if (spec.neg) { + EXPECT_NE(computation->root_instruction()->opcode(), + HloOpcode::kDynamicSlice); + } else { + EXPECT_THAT(computation->root_instruction(), + op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), + op::Concatenate())); + } +} + +std::vector<DotOfGatherTestSpec> DotOfGatherPositiveNegativeTests() { + std::vector<DotOfGatherTestSpec> positives = { + // "Classical dot", i.e. matrix multiply: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to + // dot(ct, ct) before DotOfGather optimization kicks in. + // Contract on rows: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + // Reverse matrix multiply: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + // Contract on columns: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + }; + std::vector<DotOfGatherTestSpec> all; + for (int i = 0; i < positives.size(); i++) { + DotOfGatherTestSpec positive_test = positives[i]; + all.push_back(positive_test); + DotOfGatherTestSpec negative_test = positive_test; + negative_test.neg = true; + all.push_back(negative_test); + } + return all; +} + +INSTANTIATE_TEST_CASE_P( + DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest, + ::testing::ValuesIn(DotOfGatherPositiveNegativeTests())); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 91ed6e427a..3d2e24ca14 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -535,7 +535,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( // and reduced memory usage (as compared to using DependencyHloOrdering). 
TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); + CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction(), + DFSMemoryScheduler)); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index a98e85a151..46fe060817 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -158,37 +158,95 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { EXPECT_EQ(dot, computation->root_instruction()); } -TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) { - HloComputation::Builder builder(TestName()); - HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0")); - HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1024, 256}), "arg1")); +TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_RHS) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion - HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg1)); - HloInstruction* transpose1 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0})); - builder.AddInstruction( - MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1)); +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[1024,256] parameter(1) + exponential = s32[1024,256] exponential(arg1) + transpose = s32[256,1024] transpose(exponential), dimensions={1,0} + ROOT dot = f32[1,1024] dot(arg0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + tools::Parse(hlo_string)); + HloComputation* computation = module->entry_computation(); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); TransposeFolding transpose_folding( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { return candidate_operands; }, TransposeFolding::NeverFoldTranspose); - EXPECT_TRUE(transpose_folding.Run(module.get()).ValueOrDie()); - EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion); - EXPECT_EQ(computation->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kTransposeDot); - EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); - EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion); - EXPECT_EQ(computation->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kTransposeDot); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); +} + +TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_LHS) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[256,1] parameter(0) + arg1 = 
f32[256,1024] parameter(1) + transpose = s32[1,256] transpose(arg0), dimensions={1,0} + exponential = s32[256,1024] exponential(arg1) + ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + tools::Parse(hlo_string)); + HloComputation* computation = module->entry_computation(); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + TransposeFolding::NeverFoldTranspose); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0)); +} + +TEST_F(InstructionFusionTest, + DotOperationFusion_TransposeFusion_LHS_NonDefault) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + transpose = s32[256,1] transpose(arg0), dimensions={1,0} + exponential = s32[256,1024] exponential(arg1) + ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + tools::Parse(hlo_string)); + HloComputation* computation = module->entry_computation(); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + TransposeFolding::NeverFoldTranspose); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)); } class OpcodeFusionTest : public InstructionFusionTest { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index e8117377e6..6c642080c3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -139,13 +139,9 @@ Status CpuLayoutAssignment::AddBackendConstraints( Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); - // dot is a kDot or a kTransposeDot fusion node. In the latter case, if - // it represents X @ X, it may have just one operand. - if (dot->operand_count() > 1) { - const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); - } + const HloInstruction* rhs_instruction = dot->operand(1); + Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); // Set layouts of the instructions' shapes. 
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot)); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 801c523908..8db4a0650d 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -522,16 +522,16 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( } // namespace -DotOpEmitter::DotOpEmitter( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) +DotOpEmitter::DotOpEmitter(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) : dot_(dot), - transpose_lhs_(transpose_lhs), - transpose_rhs_(transpose_rhs), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), @@ -542,18 +542,18 @@ DotOpEmitter::DotOpEmitter( target_machine_features_(target_machine_features) {} /* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); - DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, - lhs_array, rhs_array, addend_array, - executable_run_options_value, ir_builder, - hlo_module_config, target_machine_features); + DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array, + addend_array, executable_run_options_value, + ir_builder, hlo_module_config, + target_machine_features); return dot_emitter.Emit(); } @@ -578,7 +578,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (mat_mult_dims.m == 1) { bool rhs_effectively_row_major = - transpose_rhs_ ^ !mat_mult_dims.rhs_column_major; + mat_mult_dims.rhs_non_canonical ^ !mat_mult_dims.rhs_column_major; if (rhs_effectively_row_major) { k = mat_mult_dims.k; m = mat_mult_dims.n; @@ -594,7 +594,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (mat_mult_dims.n == 1) { bool lhs_effectively_column_major = - transpose_lhs_ ^ mat_mult_dims.lhs_column_major; + mat_mult_dims.lhs_non_canonical ^ mat_mult_dims.lhs_column_major; if (lhs_effectively_column_major) { m = mat_mult_dims.m; k = mat_mult_dims.k; @@ -741,16 +741,10 @@ tensorflow::Status DotOpEmitter::Emit() { // Reduce along dimension 0 of the LHS and 1 of the RHS. 
Vectors are a special // case where the reduction dimension is 0 for both LHS and RHS. This results // in a vector dot product producing a scalar. - int64 lhs_reduction_dimension = 0; - if (ShapeUtil::Rank(lhs_shape) >= 2) { - lhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(lhs_shape, transpose_lhs_ ? -2 : -1); - } - int64 rhs_reduction_dimension = 0; - if (ShapeUtil::Rank(rhs_shape) >= 2) { - rhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(rhs_shape, transpose_rhs_ ? -1 : -2); - } + int64 lhs_reduction_dimension = + dot_.dot_dimension_numbers().lhs_contracting_dimensions(0); + int64 rhs_reduction_dimension = + dot_.dot_dimension_numbers().rhs_contracting_dimensions(0); // Verify the reduction dimension in the two operands are the same size. TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == @@ -986,8 +980,8 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { const llvm_ir::IrArray* lhs = &lhs_array_; const llvm_ir::IrArray* rhs = &rhs_array_; - bool transpose_lhs = transpose_lhs_; - bool transpose_rhs = transpose_rhs_; + bool transpose_lhs = mat_mult_dims.lhs_non_canonical; + bool transpose_rhs = mat_mult_dims.rhs_non_canonical; if (!mat_mult_dims.lhs_column_major) { std::swap(mat_mult_dims.m, mat_mult_dims.n); @@ -1015,12 +1009,16 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); - - return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0), - lhs_shape.dimensions(transpose_lhs_ ? 0 : 1), - rhs_shape.dimensions(transpose_rhs_ ? 0 : 1), - LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, - LayoutUtil::Minor(rhs_shape.layout(), 0) == 0}; + const DotDimensionNumbers& dim_nums = dot_.dot_dimension_numbers(); + + return { + /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)), + /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)), + /*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)), + /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, + /*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0, + /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0, + /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1}; } llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( @@ -1090,27 +1088,16 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { // If gemm can accept the operand shapes, use it rather than a custom // kernel. if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { + const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); // The size of the reduction dimension should match. The shape inference // guarantees this invariant, so the check here is for programming // errors. 
- CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); return true; } } - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - auto* dot = hlo.fused_expression_root(); - const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - return true; - } - return false; } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 47e0924334..a20bf2f9db 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -56,16 +56,15 @@ class DotOpEmitter { // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported // for Matrix-vector products. static tensorflow::Status EmitDotOperation( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); private: - DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, - bool transpose_rhs, const llvm_ir::IrArray& target_array, + DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, @@ -114,8 +113,14 @@ class DotOpEmitter { // True if the LHS matrix column major. bool lhs_column_major; + // True if the LHS contraction dimension is not 1. + bool lhs_non_canonical; + // True if the RHS matrix column major. bool rhs_column_major; + + // True if the RHS contraction dimension is not 0. 
+ bool rhs_non_canonical; }; // Get the MatMultDims instance for the dot product this DotOpEmitter @@ -132,8 +137,6 @@ class DotOpEmitter { } const HloInstruction& dot_; - const bool transpose_lhs_; - const bool transpose_rhs_; const llvm_ir::IrArray& target_array_; const llvm_ir::IrArray& lhs_array_; const llvm_ir::IrArray& rhs_array_; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 6347ee2a2a..55e5aa5063 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -827,13 +827,6 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { "Dot with multiple contracting dimensions not implemented."); } - if (dnums.lhs_contracting_dimensions(0) != - std::min(lhs->shape().dimensions_size() - 1, 1) || - dnums.rhs_contracting_dimensions(0) != 0) { - return Unimplemented( - "Dot with non-standard contracting dimensions not implemented."); - } - llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -850,8 +843,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return DotOpEmitter::EmitDotOperation( - *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, /*addend_array=*/nullptr, + *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, target_machine_features_); } @@ -2086,44 +2078,7 @@ static const HloInstruction* StripTranspose(const HloInstruction& hlo) { Status IrEmitter::HandleFusion(HloInstruction* fusion) { auto* root = fusion->fused_expression_root(); - if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) { - DCHECK(root->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*root->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*root->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - fusion->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - fusion->operand(rhs_parameter->parameter_number()); - - TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*root, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, F64})); - - llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); - llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); - - Shape target_shape = fusion->shape(); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); - llvm_ir::IrArray target_array = GetIrArrayFor(fusion); - VLOG(2) << "HandleFusion kTransposeDot: "; - VLOG(2) << " lhs operand: " - << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); - VLOG(2) << " rhs operand: " - << llvm_ir::DumpToString(*rhs_array.GetBasePointer()); - VLOG(2) << " target: " - << llvm_ir::DumpToString(*target_array.GetBasePointer()); - - // Dot operation is complicated so we delegate to a helper class. 
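[Editor's note] With the guard below removed, the emitter accepts 2-D dots whose contracting dimensions are not the canonical (lhs=1, rhs=0) pair; the new lhs_non_canonical/rhs_non_canonical flags record exactly that, and the m/k/n extraction in the new GetMatMultDims reads m = lhs.dim(1 - lcd), k = lhs.dim(lcd), n = rhs.dim(1 - rcd). A standalone sketch of these semantics, assuming row-major storage (plain C++; Dot2D and its parameter names are hypothetical, not the XLA emitter): a non-canonical operand behaves as if it were transposed into the canonical M x K times K x N form.

#include <cassert>
#include <vector>

// dot(lhs, rhs) for 2-D row-major arrays with explicit contracting
// dimensions. Canonical form is lhs_cd == 1 and rhs_cd == 0; any other
// choice behaves as if the non-canonical operand were transposed first.
static std::vector<double> Dot2D(const std::vector<double>& lhs, int lhs_d0,
                                 int lhs_d1, int lhs_cd,
                                 const std::vector<double>& rhs, int rhs_d0,
                                 int rhs_d1, int rhs_cd) {
  const int m = (lhs_cd == 1) ? lhs_d0 : lhs_d1;  // non-contracting lhs dim
  const int k = (lhs_cd == 1) ? lhs_d1 : lhs_d0;  // contracting dim
  const int n = (rhs_cd == 0) ? rhs_d1 : rhs_d0;  // non-contracting rhs dim
  assert(k == ((rhs_cd == 0) ? rhs_d0 : rhs_d1));
  std::vector<double> out(m * n, 0.0);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int p = 0; p < k; ++p) {
        double l = (lhs_cd == 1) ? lhs[i * lhs_d1 + p] : lhs[p * lhs_d1 + i];
        double r = (rhs_cd == 0) ? rhs[p * rhs_d1 + j] : rhs[j * rhs_d1 + p];
        out[i * n + j] += l * r;
      }
  return out;
}

int main() {
  // lhs is stored K x M (contracting dimension 0), i.e. "non-canonical".
  std::vector<double> lhs = {1, 2, 3, 4, 5, 6};  // 3x2: K=3, M=2
  std::vector<double> rhs = {1, 0, 0, 1, 1, 1};  // 3x2: K=3, N=2
  std::vector<double> out =
      Dot2D(lhs, 3, 2, /*lhs_cd=*/0, rhs, 3, 2, /*rhs_cd=*/0);
  // Same result as transposing lhs to 2x3 and doing the canonical product:
  // row 0 of lhs^T is (1, 3, 5), column 0 of rhs is (1, 0, 1).
  assert(out[0] == 1 * 1 + 3 * 0 + 5 * 1);
  return 0;
}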
- TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *root, root->operand(0)->IsRank2Transpose(), - root->operand(1)->IsRank2Transpose(), target_array, lhs_array, - rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), - &ir_builder_, hlo_module_config_, target_machine_features_)); - return Status::OK(); - } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, - assignment_)) { + if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); @@ -2166,9 +2121,9 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { GetIrArrayFor(fusion->operand(addend_param_number))); TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(), - &ir_builder_, hlo_module_config_, target_machine_features_)); + *dot, target_array, lhs_array, rhs_array, &addend_array, + GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, + target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index fb28280fad..47e8405ff2 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -127,7 +127,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // Currently, we do not assign parallel tasks to instructions with at least // one of the following properties: // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). - // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Emit custom loops (kSelectAndScatter). // *) Operations that are not thread safe (like infeed and rng). // *) Tuple-shaped. // TODO(b/27458679) Parallelize instructions which are skipped here. diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index c4c56c5692..41ee45f55f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -197,22 +197,42 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // We don't put any data in these buffers, because (in theory, anyway) the // speed of a conv isn't affected by the data being convolved. 
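[Editor's note] The hunk below follows the usual autotuning pattern: allocate scratch buffers once, run every candidate algorithm, skip the ones whose launch fails, and keep the fastest valid result, falling back to a default when none succeeds. A distilled standalone sketch of that control flow (plain C++; Candidate and PickBest are hypothetical stand-ins, not the cudnn or StreamExecutor API):

#include <chrono>
#include <functional>
#include <optional>
#include <vector>

struct Candidate {
  int id;
  std::function<bool()> run;  // returns false if the algorithm fails
};

// Returns the id of the fastest candidate that ran successfully, or
// std::nullopt if none did (the caller then falls back to a default).
std::optional<int> PickBest(const std::vector<Candidate>& candidates) {
  std::optional<int> best;
  auto best_time = std::chrono::duration<double>::max();
  for (const Candidate& c : candidates) {
    auto t0 = std::chrono::steady_clock::now();
    if (!c.run()) continue;  // launch failed: skip, like !launch_ok above
    std::chrono::duration<double> elapsed = std::chrono::steady_clock::now() - t0;
    if (elapsed < best_time) {
      best_time = elapsed;
      best = c.id;
    }
  }
  return best;
}

int main() {
  std::vector<Candidate> cs = {{0, [] { return true; }},
                               {1, [] { return false; }}};
  return PickBest(cs).value_or(-1) == 0 ? 0 : 1;
}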
ScratchAllocator input_output_allocator(device_ordinal, allocator); - se::port::StatusOr<DeviceMemoryBase> input_buf = + StatusOr<DeviceMemoryBase> maybe_input_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(input_shape)); - se::port::StatusOr<DeviceMemoryBase> filter_buf = + StatusOr<DeviceMemoryBase> maybe_filter_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(filter_shape)); - se::port::StatusOr<DeviceMemoryBase> output_buf = + StatusOr<DeviceMemoryBase> maybe_output_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(output_shape)); - if (!input_buf.ok() || !filter_buf.ok() || !output_buf.ok()) { + if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() || + !maybe_output_buf.ok()) { LOG(WARNING) << "Couldn't allocate space for input/filter/output of convolution " << instr->ToString() << ". Falling back to default algorithm."; return nullopt; } + DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie(); + DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie(); + DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie(); + + // Although we don't have evidence this matters, zero out the buffers before + // autotuning. It's conceivable that using uninitialized memory as the inputs + // might affect performance if e.g. the inputs contain denormals, and this is + // easy enough. + if (!stream.ThenMemZero(&input_buf, input_buf.size()) + .ThenMemZero(&filter_buf, filter_buf.size()) + .ThenMemZero(&output_buf, output_buf.size()) + .BlockHostUntilDone() + .ok()) { + LOG(WARNING) + << "Couldn't zero out input/filter/output buffer for convolution " + << instr->ToString() << ". Falling back to default algorithm."; + return nullopt; + } + const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( input_shape, output_shape, dnums, stream_exec_); se::dnn::ProfileResult best_result; @@ -225,12 +245,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - bool launch_ok = RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - input_buf.ValueOrDie(), filter_buf.ValueOrDie(), - output_buf.ValueOrDie(), &scratch_allocator, window, - dnums, AlgorithmConfig(alg), &stream, &profile_result) - .ok(); + bool launch_ok = + RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + input_buf, filter_buf, output_buf, + &scratch_allocator, window, dnums, + AlgorithmConfig(alg), &stream, &profile_result) + .ok(); if (launch_ok && profile_result.is_valid()) { int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 0ec12f52d8..f996fe486d 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -221,8 +221,7 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, bool transpose_lhs, - bool transpose_rhs, double alpha, + const Shape& output_shape, double alpha, const HloInstruction* hlo_instruction) : Thunk(Kind::kGemm, hlo_instruction), lhs_buffer_(lhs_buffer), @@ -231,8 +230,6 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, lhs_shape_(lhs_shape), rhs_shape_(rhs_shape), output_shape_(output_shape), - 
transpose_lhs_(transpose_lhs), - transpose_rhs_(transpose_rhs), alpha_(alpha) {} tensorflow::Status GemmThunk::ExecuteOnStream( @@ -284,10 +281,13 @@ tensorflow::Status GemmThunk::ExecuteOnStream( shape.dimensions(!is_row_major)); }; - const MatrixDescriptor lhs_descriptor = - make_descriptor(lhs_data, lhs_shape_, transpose_lhs_); - const MatrixDescriptor rhs_descriptor = - make_descriptor(rhs_data, rhs_shape_, transpose_rhs_); + const DotDimensionNumbers& dim_nums = + hlo_instruction()->dot_dimension_numbers(); + + const MatrixDescriptor lhs_descriptor = make_descriptor( + lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0); + const MatrixDescriptor rhs_descriptor = make_descriptor( + rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1); // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to // autotune this gemm to figure out the best algorithm. diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index a18f425bc3..f42cbf9e94 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -35,15 +35,13 @@ namespace gpu { class GemmThunk : public Thunk { public: // Constructs a thunk that computes "output = (lhs <dot> rhs) * alpha" using - // BLAS gemm. transpose_lhs and transpose_rhs indicate whether gemm should - // transpose the lhs and rhs operand. hlo_instruction is as in Thunk. alpha is - // a constant. + // BLAS gemm. hlo_instruction is as in Thunk. alpha is a constant. GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, bool transpose_lhs, bool transpose_rhs, - double alpha, const HloInstruction* hlo_instruction); + const Shape& output_shape, double alpha, + const HloInstruction* hlo_instruction); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; @@ -69,8 +67,6 @@ class GemmThunk : public Thunk { const Shape rhs_shape_; const Shape output_shape_; - const bool transpose_lhs_; - const bool transpose_rhs_; const double alpha_; // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index ece9fa04dc..6436abc06c 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -65,9 +65,9 @@ TEST_F(HloScheduleTest, SequentialMatMul) { HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -193,11 +193,11 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + 
HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(add)); @@ -259,24 +259,24 @@ TEST_F(HloScheduleTest, LatticeMatMul) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); } - HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( - f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d00 = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 85ecbe8fdb..c5eb721185 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -48,6 +48,19 @@ bool IsFusile(const HloInstruction& hlo) { } // namespace +/*static*/ bool GpuInstructionFusion::IsExpensive( + const HloInstruction& instruction) { + switch (instruction.opcode()) { + // We say that floating-point division is cheap on the GPU. 
+ case HloOpcode::kDivide: + return !ShapeUtil::ElementIsFloating(instruction.shape()) && + InstructionFusion::IsExpensive(instruction); + + default: + return InstructionFusion::IsExpensive(instruction); + } +} + bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index bb2990e6df..9fb06b0a24 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -27,6 +27,8 @@ class GpuInstructionFusion : public InstructionFusion { explicit GpuInstructionFusion(bool may_duplicate) : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} + static bool IsExpensive(const HloInstruction& instruction); + bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; HloInstruction::FusionKind ChooseKind( diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 4b231c449f..6c9a805ad6 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -253,5 +253,61 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { op::Dot(op::Parameter(), op::Transpose(op::Parameter())))); } +// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is +// duplicated and fused into both reduces. +TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { + auto module = tools::Parse(R"( + HloModule test_module + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = f32[] constant(0) + one = f32[] constant(1) + p0 = f32[100] parameter(0) + recip = f32[100] divide(one, p0) + sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT root = (f32[], f32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())); +} + +// Compute sum(100/p0), where p0 has type s32, twice. Check that the division +// is *not* duplicated and fused into both reduces, because we say that integer +// division is not cheap. 
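[Editor's note] A distilled model of the duplication policy that the two tests around here exercise (standalone C++; the enum and helpers are simplifications for illustration, not actual XLA types): division counts as expensive by default, so it is not duplicated into multiple consumers, but the GPU backend overrides this for floating-point divides, which it treats as cheap.

#include <cassert>

enum class Opcode { kAdd, kDivide };

// Base policy: division is expensive everywhere.
bool BaseIsExpensive(Opcode op) { return op == Opcode::kDivide; }

// GPU override, mirroring GpuInstructionFusion::IsExpensive above:
// only non-floating-point divides defer to the base policy.
bool GpuIsExpensive(Opcode op, bool element_is_floating) {
  if (op == Opcode::kDivide) {
    return !element_is_floating && BaseIsExpensive(op);
  }
  return BaseIsExpensive(op);
}

int main() {
  assert(!GpuIsExpensive(Opcode::kDivide, /*element_is_floating=*/true));
  assert(GpuIsExpensive(Opcode::kDivide, /*element_is_floating=*/false));
  assert(!GpuIsExpensive(Opcode::kAdd, /*element_is_floating=*/true));
  return 0;
}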
+TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { + auto module = tools::Parse(R"( + HloModule test_module + Add { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = s32[] constant(0) + one_hundred = s32[] constant(100) + p0 = s32[100] parameter(0) + recip = s32[100] divide(one_hundred, p0) + sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT mul = (s32[], s32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 532d436ee8..96199035b9 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -78,18 +78,14 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { // The size of the reduction dimension should match. The shape inference // guarantees this invariant, so the check here is for programming // errors. - CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); + const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); return true; } } if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - return true; - } - - if (hlo.opcode() == HloOpcode::kFusion && hlo.fusion_kind() == HloInstruction::FusionKind::kOutput && hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply) { // Try to find the dot inside the output fusion node. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 9f37235d32..83d90296df 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2206,65 +2206,37 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk( lhs->shape(), // The shape of LHS. rhs->shape(), // The shape of RHS. inst->shape(), // The shape of the output. - false, // Do not transpose LHS. - false, // Do not transpose RHS. 1.0, // alpha. inst); } if (inst->opcode() == HloOpcode::kFusion) { - if (inst->fusion_kind() == HloInstruction::FusionKind::kOutput) { - const HloInstruction* mul = inst->fused_expression_root(); - const HloInstruction* dot = mul->operand(0); - const HloInstruction* alpha = mul->operand(1); - if (dot->opcode() != HloOpcode::kDot) { - std::swap(dot, alpha); - } - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - inst->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - inst->operand(rhs_parameter->parameter_number()); - - return MakeUnique<GemmThunk>( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*mul), // The output buffer. 
- lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - dot->operand(0)->IsRank2Transpose(), // Transpose LHS. - dot->operand(1)->IsRank2Transpose(), // Transpose RHS. - alpha->literal().Get<double>({0}), // alpha. - inst); - } else { - const HloInstruction* dot = inst->fused_expression_root(); - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - inst->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - inst->operand(rhs_parameter->parameter_number()); - - return MakeUnique<GemmThunk>( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*inst), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - dot->operand(0)->IsRank2Transpose(), // Transpose LHS. - dot->operand(1)->IsRank2Transpose(), // Transpose RHS. - 1.0, // Alpha. - inst); + CHECK_EQ(inst->fusion_kind(), HloInstruction::FusionKind::kOutput); + const HloInstruction* mul = inst->fused_expression_root(); + const HloInstruction* dot = mul->operand(0); + const HloInstruction* alpha = mul->operand(1); + if (dot->opcode() != HloOpcode::kDot) { + std::swap(dot, alpha); } + DCHECK(dot->opcode() == HloOpcode::kDot); + const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); + const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); + DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && + rhs_parameter->opcode() == HloOpcode::kParameter); + const HloInstruction* lhs = + inst->operand(lhs_parameter->parameter_number()); + const HloInstruction* rhs = + inst->operand(rhs_parameter->parameter_number()); + + return MakeUnique<GemmThunk>( + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*mul), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + alpha->literal().Get<double>({0}), // alpha. 
+ inst); } LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString(); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 8c98956f1a..b42767dfd5 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -41,9 +41,9 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -60,9 +60,9 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); @@ -91,24 +91,24 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); } - HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( - f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d00 = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d30, 
d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index aa6860880b..1f7c1cffd3 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -147,6 +147,9 @@ message HloInstructionProto { repeated int64 called_computation_ids = 38; xla.OpSharding sharding = 40; + + // Backend configuration for the instruction. Has backend-specific meaning. + string backend_config = 43; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 594413e88f..17e43c3cb8 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -347,6 +347,11 @@ std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList() // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after // construction. + // + // TODO(b/78350259): This violates const-correctness, since while the original + // computation is not returned, we still retrieve non-const computations from + // a const one. Consider also avoiding const for HloComputation, or review XLA + // for const-correctness of non-HloInstruction* types like this. ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited, &post_order); @@ -723,18 +728,25 @@ Status HloComputation::Accept( return this->Accept(&visitor); } -std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix, - HloModule* module) { +std::unique_ptr<HloComputation> HloComputation::Clone( + const string& suffix, HloModule* module, + HloInstruction::CloneMap* clone_map) { return CloneWithReplacements( /*replacements=*/std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>(), - module, suffix); + module, clone_map, suffix); } std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements( std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>> replacements, - HloModule* module, const string& suffix) { + HloModule* module, HloInstruction::CloneMap* clone_map, + const string& suffix) { + HloInstruction::CloneMap local_clone_map; + if (clone_map == nullptr) { + clone_map = &local_clone_map; + } + // Look up instr in the replacements map, and return either the replacement, // or instr, if the replacement isn't present. // @@ -756,24 +768,19 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements( } } - std::unordered_map<HloInstruction*, HloInstruction*> clone_map; std::vector<std::unique_ptr<HloInstruction>> instructions; std::unique_ptr<HloInstruction> new_instr = nullptr; for (auto instr : postorder) { std::vector<HloInstruction*> new_operands; for (auto operand : instr->operands()) { auto replaced_operand = replace(operand); - // If replaced_operand is null, that means 'replacements' asked us not to - // include operand in the new computation. But we can't do that, because - // operand is used by instr. 
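-      // For example (an illustrative scenario, not from this change): if
-      // 'replacements' maps an add instruction to nullptr while a multiply in
-      // the same computation still consumes that add, there is no valid
-      // operand to wire into the cloned multiply, so the check below fails.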
CHECK_NE(replaced_operand, nullptr)
-          << "replacements map tried to eliminate a used instruction "
-          << operand->ToString() << ", used by " << instr->ToString();
-      new_operands.push_back(FindOrDie(clone_map, replaced_operand));
+          << "Replacements map specifies to leave out " << operand->ToString()
+          << ", but it is used by " << instr->ToString() << ".";
+      new_operands.push_back(FindOrDie(*clone_map, replaced_operand));
     }
-    new_instr =
-        instr->CloneWithNewOperands(instr->shape(), new_operands, module);
-    InsertOrDie(&clone_map, instr, new_instr.get());
+    new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands,
+                                            module, clone_map);
     instructions.push_back(std::move(new_instr));
   }
   Builder builder(name() + "." + suffix);
@@ -781,27 +788,24 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
     builder.AddInstruction(std::move(instr));
   }
   auto result = builder.Build(
-      /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction())));
+      /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction())));
 
   // Clone control dependencies.
   for (auto instr : postorder) {
-    HloInstruction* new_instr = FindOrDie(clone_map, instr);
+    HloInstruction* new_instr = FindOrDie(*clone_map, instr);
     for (auto successor : instr->control_successors()) {
       auto replaced_successor = replace(successor);
-
-      // successor may not be in clone_map, because it might have been
-      // removed by the replacements map.
-      if (replaced_successor == nullptr) {
-        continue;
-      }
+      CHECK_NE(replaced_successor, nullptr)
+          << "Replacements map specifies to leave out " << successor->ToString()
+          << ", but it is control-depended-on by " << instr->ToString() << ".";
 
       TF_CHECK_OK(new_instr->AddControlDependencyTo(
-          FindOrDie(clone_map, replaced_successor)));
+          FindOrDie(*clone_map, replaced_successor)));
     }
   }
 
   // We cloned the elements of 'replacements', so they're all going to be
-  // destroyed. HloInstructions need to be detached from their operands before
+  // destroyed.  HloInstructions need to be detached from their operands before
   // they're destroyed, otherwise they stick around in the operands' users lists
   // and cause use-after-frees.
   for (auto& kv : replacements) {
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 9d3f6e9a2c..9898355625 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -291,11 +291,17 @@ class HloComputation {
       const std::function<Status(const HloInstruction*)>& visitor_func) const;
 
   // Returns a deep copy of this computation including all instructions.
-  // If the module pointer is not nullptr, it will be the module where
-  // the cloned computations will be added to (in order to support deep
-  // cloning).
-  std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
-                                        HloModule* module = nullptr);
+  //
+  // If the module pointer is not nullptr, then the cloned computations will be
+  // added to this module in order to support deep cloning. Otherwise the
+  // module of the computation is used.
+  //
+  // If clone_map is not nullptr, then each original instruction that is cloned
+  // will be inserted and mapped to its clone. clone_map should not already
+  // contain any of the instructions to clone.
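+  //
+  // For example (an illustrative sketch; the variable names are hypothetical):
+  //   HloInstruction::CloneMap clone_map;
+  //   auto copy = computation->Clone("copy", /*module=*/nullptr, &clone_map);
+  //   // clone_map[instr] now yields the clone of each original instruction.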
+ std::unique_ptr<HloComputation> Clone( + const string& suffix = "clone", HloModule* module = nullptr, + HloInstruction::CloneMap* clone_map = nullptr); // Like Clone(), but if an instruction is present in replacement_map, we use // the map's value to replace that instruction in the cloned computation. @@ -305,7 +311,9 @@ class HloComputation { std::unique_ptr<HloComputation> CloneWithReplacements( std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>> replacements, - HloModule* module = nullptr, const string& suffix = "clone"); + HloModule* module = nullptr, + HloInstruction::CloneMap* clone_map = nullptr, + const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 9a89888480..ed3b654851 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -269,7 +269,7 @@ StatusOr<HloInstruction*> BroadcastZeros( StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature( ArraySlice<const Shape*> domain, const Shape& range, tensorflow::StringPiece name) { - HloComputation::Builder b(name.ToString()); + HloComputation::Builder b{std::string(name)}; int64 param_idx = 0; for (const Shape* param_shape : domain) { b.AddInstruction(HloInstruction::CreateParameter( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 1071f5b184..e7425c8ba7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" @@ -42,7 +43,6 @@ limitations under the License. 
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/optional.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/types.h"
@@ -53,19 +53,6 @@ namespace {
 
 using tensorflow::gtl::ArraySlice;
 using tensorflow::gtl::FlatSet;
-using tensorflow::gtl::optional;
-
-template <typename T>
-struct is_complex_t : public std::false_type {};
-
-template <>
-struct is_complex_t<complex64> : public std::true_type {};
-
-template <typename T>
-struct is_complex64_t : public std::false_type {};
-
-template <>
-struct is_complex64_t<complex64> : public std::true_type {};
 
 template <typename OperandT>
 StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
@@ -147,2092 +134,48 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
   return std::move(result);
 }
 
-template <typename ReturnT, typename NativeT>
-StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
-    HloInstruction* instruction,
-    const std::function<ReturnT(NativeT)>& unary_op,
-    const Literal& operand_literal) {
-  const auto shape = instruction->shape();
-  const auto* operand = instruction->operand(0);
-
-  // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
-  // removed.
-  if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
-    return Unimplemented(
-        "Implicit broadcasting is currently unsupported in HLO evaluator "
-        "Shape Mismatch: %s vs %s",
-        ShapeUtil::HumanString(shape).c_str(),
-        ShapeUtil::HumanString(operand->shape()).c_str());
-  }
-
-  auto result = Literal::CreateFromShape(shape);
-
-  TF_RETURN_IF_ERROR(
-      result->Populate<ReturnT>([&](ArraySlice<int64> multi_index) {
-        return unary_op(operand_literal.Get<NativeT>(multi_index));
-      }));
-  return std::move(result);
-}
-
-// For one particular placement of a window in a base shape (the placement is
-// represented as `window_count_index`), iterates inside the window. Translates
-// each window index into a base index. If the base index is within bounds,
-// calls `f` with the base index.
-void IterateThroughWindow(
-    const Shape& window_shape, const Window& window, const Shape& base_shape,
-    const ArraySlice<int64>& window_count_index,
-    const std::function<void(const std::vector<int64>&)>& f) {
-  const int64 rank = ShapeUtil::Rank(base_shape);
-  DimensionVector window_index(rank);
-  std::fill(window_index.begin(), window_index.end(), 0);
-  do {
-    std::vector<int64> base_index(rank);
-    bool out_of_bound = false;
-    for (int64 i = 0; i < rank; ++i) {
-      base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
-                      window_index[i] - window.dimensions(i).padding_low();
-      if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
-        out_of_bound = true;
-        break;
-      }
-    }
-    if (!out_of_bound) {
-      f(base_index);
-    }
-  } while (IndexUtil::BumpIndices(window_shape, &window_index));
-}
-
-// Creates a vector of multipliers which can be used to create a linear index
-// into shape.
-//
-// Given the multidimensional index {i1, ..., iN} and
-// M = MakeDimMultipliers(shape), the corresponding linear index LI is simply
-//
-//   LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N].
-//
-// This lets you calculate LI given the multidimensional indices in any order.
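-//
-// For example (a worked instance, not from the original source): a row-major
-// shape [2,3] has minor-to-major order {1,0}, so M = {3,1}, and the index
-// {1,2} maps to LI = 1*3 + 2*1 = 5.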
-DimensionVector MakeDimMultipliers(const Shape& shape) {
-  DimensionVector v(ShapeUtil::Rank(shape));
-  int64 scale = 1;
-  for (auto dim : LayoutUtil::MinorToMajor(shape)) {
-    v[dim] = scale;
-    scale *= shape.dimensions(dim);
-  }
-  return v;
-}
-
 }  // namespace
 
-template <typename ReturnT, typename ElementwiseT>
-class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
- public:
-  explicit TypedVisitor(HloEvaluator* p) : parent_(p) {}
-
-  // The following higher-order functions convert a function with ElementwiseT
-  // to a function with ReturnT.
-  std::function<ReturnT(ReturnT)> ConvertUnaryFunction(
-      const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
-    return [&unary_op](ReturnT arg) {
-      return static_cast<ReturnT>(unary_op(static_cast<ElementwiseT>(arg)));
-    };
-  }
-  std::function<ReturnT(ReturnT, ReturnT)> ConvertBinaryFunction(
-      const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
-          binary_op) {
-    return [&binary_op](ReturnT arg1, ReturnT arg2) {
-      return static_cast<ReturnT>(binary_op(static_cast<ElementwiseT>(arg1),
-                                            static_cast<ElementwiseT>(arg2)));
-    };
-  }
-  std::function<ReturnT(ReturnT, ReturnT, ReturnT)> ConvertTernaryFunction(
-      const std::function<ElementwiseT(ElementwiseT, ElementwiseT,
-                                       ElementwiseT)>& ternary_op) {
-    return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) {
-      return static_cast<ReturnT>(ternary_op(static_cast<ElementwiseT>(arg1),
-                                             static_cast<ElementwiseT>(arg2),
-                                             static_cast<ElementwiseT>(arg3)));
-    };
-  }
-
-  Status DefaultAction(HloInstruction* hlo_instruction) override {
-    return Unimplemented("unhandled HLO ops for HloEvaluator: %s.",
-                         HloOpcodeString(hlo_instruction->opcode()).c_str());
-  }
-
-  // TODO(b/35950897): many of the STL functions used in the handlers are not
-  // overloaded for every XLA primitive type.
-
-  template <typename NativeT,
-            typename std::enable_if<std::is_unsigned<NativeT>::value>::type* =
-                nullptr>
-  Status HandleAbs(HloInstruction* abs) {
-    TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
-                        ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
-                          return elem_operand;
-                        }));
-    return Status::OK();
-  }
-
-  template <
-      typename NativeT,
-      typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr>
-  Status HandleAbs(HloInstruction* abs) {
-    TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs],
-                        ElementWiseUnaryOp(abs, [](NativeT elem_operand) {
-                          return std::abs(elem_operand);
-                        }));
-    return Status::OK();
-  }
-
-  template <
-      typename NativeT,
-      typename std::enable_if<is_complex64_t<NativeT>::value>::type* = nullptr>
-  Status HandleAbs(HloInstruction* abs) {
-    const Literal& operand_literal =
-        parent_->GetEvaluatedLiteralFor(abs->operand(0));
-    TF_ASSIGN_OR_RETURN(
-        parent_->evaluated_[abs],
-        (ElementWiseUnaryOpImpl<float, NativeT>(
-            abs, [](NativeT elem_operand) { return std::abs(elem_operand); },
-            operand_literal)));
-
-    return Status::OK();
-  }
-
-  Status HandleAbs(HloInstruction* abs) override {
-    // If the operand is of C64 type, the return type of abs will be F32.
-    // However, ElementwiseT would still be the return type, F32, and thus
-    // ElementwiseT must be specified explicitly as C64 below.
- if (abs->operand(0)->shape().element_type() == C64) { - return HandleAbs<complex64>(abs); - } - return HandleAbs<ElementwiseT>(abs); - } - - template < - typename NativeT, - typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleRound(HloInstruction* round) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[round], - ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { - return std::round(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleRound(HloInstruction* round) { - return InvalidArgument("Unsupported type for Round"); - } - - Status HandleRound(HloInstruction* round) override { - return HandleRound<ReturnT>(round); - } - - Status HandleBroadcast(HloInstruction* broadcast) override { - parent_->evaluated_[broadcast] = - Literal::CreateFromShape(broadcast->shape()); - auto output = parent_->evaluated_[broadcast].get(); - const Literal& operand_to_broadcast = - parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); - std::vector<int64> broadcast_indices( - ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); - - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand_to_broadcast.shape())) - << "broadcast dimensions is of size: " << broadcast->dimensions().size() - << " and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand_to_broadcast.shape()); - // Checks that operand's dimensions are the same as the broadcast's - // dimensions along the dimensions to be broadcasted. - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand_to_broadcast.shape().dimensions(i)); - } - - return output->Populate<ReturnT>([&](ArraySlice<int64> multi_index) { - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; - } - return operand_to_broadcast.Get<ReturnT>(broadcast_indices); - }); - } - - template < - typename NativeT, - typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleCeil(HloInstruction* ceil) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], - ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { - return std::ceil(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleCeil(HloInstruction* ceil) { - return InvalidArgument("Unsupported type for Ceil"); - } - - Status HandleCeil(HloInstruction* ceil) override { - return HandleCeil<ReturnT>(ceil); - } - - Status HandleConvert(HloInstruction* convert) override { - const HloInstruction* operand = convert->operand(0); - TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result, - parent_->GetEvaluatedLiteralFor(operand).Convert( - convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } - return Status::OK(); - } - - Status HandleBitcastConvert(HloInstruction* convert) override { - const HloInstruction* operand = convert->operand(0); - TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result, - 
parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( - convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } - return Status::OK(); - } - - Status HandleExp(HloInstruction* exp) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], - ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { - return std::exp(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleFloor(HloInstruction* floor) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[floor], - ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { - return std::floor(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleFloor(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Floor"); - } - - Status HandleFloor(HloInstruction* floor) override { - return HandleFloor<ReturnT>(floor); - } - - Status HandleLog(HloInstruction* log) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], - ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { - return std::log(elem_operand); - })); - return Status::OK(); - } - - template <typename NativeT, - typename std::enable_if< - std::is_integral<NativeT>::value && - !std::is_same<NativeT, bool>::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return ~elem_operand; - })); - return Status::OK(); - } - - template <typename NativeT, typename std::enable_if<std::is_floating_point< - NativeT>::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return !elem_operand; - })); - return Status::OK(); - } - - template <typename NativeT, - typename std::enable_if<std::is_same<NativeT, bool>::value>::type* = - nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return !elem_operand; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - return InvalidArgument("Unsupported type for Not"); - } - - Status HandleNot(HloInstruction* not_) override { - return HandleNot<ElementwiseT>(not_); - } - - template <typename NativeT, - typename std::enable_if< - std::is_signed<NativeT>::value && - !std::is_floating_point<NativeT>::value>::type* = nullptr> - Status HandleNegate(HloInstruction* negate) { - using type = typename std::make_unsigned<NativeT>::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[negate], - ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { - return NativeT(-type(elem_operand)); - })); - return Status::OK(); - } - - template <typename NativeT, - typename std::enable_if< - !std::is_signed<NativeT>::value || - std::is_floating_point<NativeT>::value>::type* = nullptr> - Status HandleNegate(HloInstruction* negate) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[negate], - ElementWiseUnaryOp( - negate, [](ElementwiseT elem_operand) { return 
-elem_operand; })); - return Status::OK(); - } - - Status HandleNegate(HloInstruction* negate) override { - return HandleNegate<ReturnT>(negate); - } - - template < - typename NativeT, - typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleSign(HloInstruction* sign) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { - return (ElementwiseT(0) < elem_operand) - - (elem_operand < ElementwiseT(0)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleSign(HloInstruction* sign) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { - auto abs_val = std::abs(elem_operand); - return 0 == abs_val ? ElementwiseT(0) - : elem_operand / abs_val; - })); - return Status::OK(); - } - - Status HandleSign(HloInstruction* sign) override { - return HandleSign<ReturnT>(sign); - } - - template <typename NativeT, typename std::enable_if<std::is_floating_point< - NativeT>::value>::type* = nullptr> - Status HandleAtan2(HloInstruction* atan2) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], - ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return std::atan2(lhs_elem, rhs_elem); - })); - return Status::OK(); - } - - template <typename NativeT, typename std::enable_if<!std::is_floating_point< - NativeT>::value>::type* = nullptr> - Status HandleAtan2(HloInstruction* atan2) { - return InvalidArgument("Unsupported type for Atan2"); - } - - Status HandleAtan2(HloInstruction* atan2) override { - return HandleAtan2<ElementwiseT>(atan2); - } - - Status HandleTanh(HloInstruction* tanh) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], - ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { - return std::tanh(elem_operand); - })); - return Status::OK(); - } - - template <typename NativeT, - typename std::enable_if< - std::is_signed<NativeT>::value && - !std::is_floating_point<NativeT>::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - using type = typename std::make_unsigned<NativeT>::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return NativeT(type(lhs_elem) * type(rhs_elem)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<std::is_unsigned<NativeT>::value || - std::is_floating_point<NativeT>::value || - is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem * rhs_elem; - })); - return Status::OK(); - } - - Status HandleMultiply(HloInstruction* multiply) override { - return HandleMultiply<ElementwiseT>(multiply); - } - - Status HandleSubtract(HloInstruction* subtract) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[subtract], - ElementWiseBinaryOp(subtract, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem - rhs_elem; - })); - return Status::OK(); - } - - Status HandleAdd(HloInstruction* add) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], - ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return lhs_elem + rhs_elem; - })); - return Status::OK(); - } - - Status 
HandleDivide(HloInstruction* divide) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], - ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return lhs_elem / rhs_elem; - })); - return Status::OK(); - } - - template <typename NativeT, - typename std::enable_if<std::is_integral<NativeT>::value>::type* = - nullptr> - Status HandleMaximum(HloInstruction* maximum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[maximum], - ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { - return std::max(lhs, rhs); - })); - return Status::OK(); - } - - template <typename NativeT, typename std::enable_if<std::is_floating_point< - NativeT>::value>::type* = nullptr> - Status HandleMaximum(HloInstruction* maximum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[maximum], - ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { - return ((lhs >= rhs) || std::isnan(lhs)) ? lhs : rhs; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleMaximum(HloInstruction* maximum) { - return InvalidArgument("Unsupported type for Maximum"); - } - - Status HandleMaximum(HloInstruction* maximum) override { - return HandleMaximum<ElementwiseT>(maximum); - } - - template <typename NativeT, - typename std::enable_if<std::is_integral<NativeT>::value>::type* = - nullptr> - Status HandleMinimum(HloInstruction* minimum) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], - ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::min(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template <typename NativeT, typename std::enable_if<std::is_floating_point< - NativeT>::value>::type* = nullptr> - Status HandleMinimum(HloInstruction* minimum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[minimum], - ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? 
lhs_el : rhs_el; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleMinimum(HloInstruction* minimum) { - return InvalidArgument("Unsupported type for Minimum"); - } - - Status HandleMinimum(HloInstruction* minimum) override { - return HandleMinimum<ElementwiseT>(minimum); - } - - Status HandlePower(HloInstruction* power) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], - ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::pow(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleRemainder(HloInstruction* remainder) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], - ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::fmod(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleRemainder(HloInstruction* remainder) { - return InvalidArgument("Unsupported type for Remainder"); - } - - Status HandleRemainder(HloInstruction* remainder) override { - return HandleRemainder<ElementwiseT>(remainder); - } - - template <typename NativeT, - typename std::enable_if<std::is_integral<NativeT>::value>::type* = - nullptr> - Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el & rhs_el; - })); - return Status::OK(); - } - - template <typename NativeT, typename std::enable_if<std::is_floating_point< - NativeT>::value>::type* = nullptr> - Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el && rhs_el; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); - } - - Status HandleAnd(HloInstruction* and_) override { - return HandleAnd<ElementwiseT>(and_); - } - - template <typename NativeT, - typename std::enable_if<std::is_integral<NativeT>::value>::type* = - nullptr> - Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el | rhs_el; - })); - return Status::OK(); - } - - template <typename NativeT, typename std::enable_if<std::is_floating_point< - NativeT>::value>::type* = nullptr> - Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el || rhs_el; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleOr(HloInstruction* or_) { - return InvalidArgument("Unsupported type for Or"); - } - - Status HandleOr(HloInstruction* or_) override { - return HandleOr<ElementwiseT>(or_); - } - - template <typename NativeT, - typename std::enable_if< - std::is_integral<NativeT>::value && - !std::is_same<NativeT, bool>::value>::type* = nullptr> - Status HandleShiftLeft(HloInstruction* shl) { - 
TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shl], - ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { - return IsShiftOutOfBounds<NativeT>(rhs_elem) ? 0 - : (lhs_elem << rhs_elem); - })); - return Status::OK(); - } - - template <typename NativeT, - typename std::enable_if<!std::is_integral<NativeT>::value || - std::is_same<NativeT, bool>::value>::type* = - nullptr> - Status HandleShiftLeft(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftLeft"); - } - - Status HandleShiftLeft(HloInstruction* shl) override { - return HandleShiftLeft<ElementwiseT>(shl); - } - template <typename NativeT, - typename std::enable_if< - std::is_integral<NativeT>::value && - !std::is_same<NativeT, bool>::value>::type* = nullptr> - Status HandleShiftRightArithmetic(HloInstruction* shr) { - typedef typename std::make_signed<NativeT>::type SignedT; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shr], - ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { - SignedT lhs_signed = static_cast<SignedT>(lhs_elem); - if (IsShiftOutOfBounds<NativeT>(rhs_elem)) { - return lhs_signed < 0 ? static_cast<SignedT>(-1) : 0; - } else { - return lhs_signed >> rhs_elem; - } - })); - return Status::OK(); - } - - template <typename NativeT, - typename std::enable_if<!std::is_integral<NativeT>::value || - std::is_same<NativeT, bool>::value>::type* = - nullptr> - Status HandleShiftRightArithmetic(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightArithmetic"); - } - - Status HandleShiftRightArithmetic(HloInstruction* shra) override { - return HandleShiftRightArithmetic<ElementwiseT>(shra); - } - - template <typename NativeT, - typename std::enable_if< - std::is_integral<NativeT>::value && - !std::is_same<NativeT, bool>::value>::type* = nullptr> - Status HandleShiftRightLogical(HloInstruction* shr) { - typedef typename std::make_unsigned<NativeT>::type UnsignedT; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shr], - ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { - // If shift amount is greater than the number of bits, then return 0. 
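-          // For example (illustrative): for uint32 operands a shift amount of
-          // 35 is out of bounds, and 1u >> 35 would be undefined behavior in
-          // C++, so the result is defined to be 0 instead.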
- if (IsShiftOutOfBounds<NativeT>(rhs_elem)) { - return static_cast<NativeT>(0); - } - return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >> - rhs_elem); - })); - return Status::OK(); - } - - template <typename NativeT, - typename std::enable_if<!std::is_integral<NativeT>::value || - std::is_same<NativeT, bool>::value>::type* = - nullptr> - Status HandleShiftRightLogical(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightLogical"); - } - - Status HandleShiftRightLogical(HloInstruction* shrl) override { - return HandleShiftRightLogical<ElementwiseT>(shrl); - } - - template < - typename NativeT, - typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleClamp(HloInstruction* clamp) { - std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)> - clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { - return std::fmin(high, std::fmax(value, low)); - }; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[clamp], - ElementwiseTernaryOp(clamp, - std::move(ConvertTernaryFunction(clamp_op)))); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleClamp(HloInstruction*) { - return InvalidArgument("Unsupported type for Clamp"); - } - - Status HandleClamp(HloInstruction* clamp) override { - return HandleClamp<ElementwiseT>(clamp); - } - - Status HandleSelect(HloInstruction* select) override { - CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); - CHECK(!ShapeUtil::IsTuple(select->shape())); - std::function<ReturnT(bool, ReturnT, ReturnT)> select_op = - [](bool pred, ReturnT on_true, ReturnT on_false) { - if (pred) { - return on_true; - } - return on_false; - }; - TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], - ElementwiseTernaryOp(select, std::move(select_op))); - return Status::OK(); - } - - Status HandleReverse(HloInstruction* reverse) override { - const auto result_shape = reverse->shape(); - const auto reverse_dimensions = reverse->dimensions(); - - auto operand = reverse->operand(0); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferReverseShape(operand->shape(), - reverse_dimensions)); - - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(result_shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = Literal::CreateFromShape(result_shape); - - TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](ArraySlice<int64> out_index) { - std::vector<int64> from_index(out_index.begin(), out_index.end()); - for (const int64 dim : reverse_dimensions) { - from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; - } - return operand_literal.Get<ReturnT>(from_index); - })); - - parent_->evaluated_[reverse] = std::move(result); - return Status::OK(); - } - - Status HandleConvolution(HloInstruction* conv) override { - auto lhs = conv->operand(0); - auto rhs = conv->operand(1); - const auto& window = conv->window(); - const Shape& result_shape = conv->shape(); - const Shape& lhs_shape = lhs->shape(); - const Shape& rhs_shape = rhs->shape(); - - TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); - TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); - CHECK(ShapeUtil::IsArray(lhs_shape)); - CHECK(ShapeUtil::IsArray(rhs_shape)); - CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); - 
CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); - - const auto& dnums = conv->convolution_dimension_numbers(); - const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); - CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); - CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); - CHECK_GE(num_spatial_dims, 0); - CHECK_EQ(window.dimensions_size(), num_spatial_dims); - - const auto lhs_rank = ShapeUtil::Rank(lhs_shape); - const auto rhs_rank = ShapeUtil::Rank(rhs_shape); - - CHECK_EQ(num_spatial_dims + 2, lhs_rank); - CHECK_EQ(num_spatial_dims + 2, rhs_rank); - - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, - window, dnums)); - CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(result_shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - std::vector<int64> window_dimension_sizes; - for (auto i : dnums.kernel_spatial_dimensions()) { - window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); - } - - const Shape& window_shape = - ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes); - - DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape); - DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape); - - auto lhs_literal_data = lhs_literal.data<ReturnT>(); - auto rhs_literal_data = rhs_literal.data<ReturnT>(); - - auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, - &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data](ArraySlice<int64> out_index) { - // Dimension number applicable for input (lhs). - const int64 input_batch_dim = dnums.input_batch_dimension(); - const int64 input_z_dim = dnums.input_feature_dimension(); - // Dimension number applicable for kernel (rhs). - const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); - const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); - // Dimension number applicable for output. - const int64 output_batch_dim = dnums.output_batch_dimension(); - const int64 output_z_dim = dnums.output_feature_dimension(); - - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); - - ElementwiseT result_val = static_cast<ElementwiseT>(0); - DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), - 0); - - // Convolve input feature with kernel. - do { - for (int64 iz = 0; iz < z_size; ++iz) { - int64 lhs_linear_index = 0; - lhs_linear_index += out_index[output_batch_dim] * - lhs_dim_multipliers[input_batch_dim]; - lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; - - int64 rhs_linear_index = 0; - rhs_linear_index += out_index[output_z_dim] * - rhs_dim_multipliers[kernel_output_z_dim]; - rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; - - // Find corresponding spatial dimension index for input (lhs). - for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { - // Spatial dimension number for input (lhs) and output. - const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); - const int64 output_spatial_dim = - dnums.output_spatial_dimensions(ki); - - // Calculate lhs (input) index without taking base dilation into - // account. 
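-            // For example (a worked instance, not from the original source):
-            // with stride 2, padding_low 1 and window_dilation 3, an output
-            // index of 4 and rhs_spatial_index of 1 give
-            //   undilated_index = 4 * 2 - 1 + 1 * 3 = 10.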
- const auto& window_dim = window.dimensions(ki); - const int64 undilated_index = - out_index[output_spatial_dim] * window_dim.stride() - - window_dim.padding_low() + - rhs_spatial_index[ki] * window_dim.window_dilation(); - // Skip if the lhs (input) index is to be dilated. As an - // optimization, skip this mod if there's no dilation. - if (window_dim.base_dilation() > 1 && - undilated_index % window_dim.base_dilation() != 0) { - goto cnt; - } - - // Calculate the actual lhs (input) index after dilation. As an - // optimization, skip this integer divide if there's no dilation. - int64 lhs_spatial_index; - if (window_dim.base_dilation() > 1) { - lhs_spatial_index = undilated_index / window_dim.base_dilation(); - } else { - lhs_spatial_index = undilated_index; - } - lhs_linear_index += - lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; - - // Skip if input index is not in bounds. - if (!(lhs_spatial_index >= 0 && - lhs_spatial_index < - lhs_shape.dimensions(input_spatial_dim))) { - goto cnt; - } - - rhs_linear_index += - (window_dim.window_reversal() - ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) - : rhs_spatial_index[ki]) * - rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; - } - - result_val += - static_cast<ElementwiseT>(lhs_literal_data[lhs_linear_index]) * - static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]); - } - cnt : {} - } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); - - return static_cast<ReturnT>(result_val); - }; - - auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func)); - - parent_->evaluated_[conv] = std::move(result); - return Status::OK(); - } - - Status HandleDot(HloInstruction* dot) override { - auto lhs = dot->operand(0); - auto rhs = dot->operand(1); - CHECK(ShapeUtil::IsArray(dot->shape())); - CHECK(ShapeUtil::IsArray(lhs->shape())); - CHECK(ShapeUtil::IsArray(rhs->shape())); - - const auto& dnums = dot->dot_dimension_numbers(); - - const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); - const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); - - CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); - CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); - - // There must be 1 and only 1 Contracting dimension for lhs and rhs. - CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); - CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); - const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); - const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); - // Contracted dimension sizes must be the same. 
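-    // For example, a canonical matmul of lhs [m, k] with rhs [k, n] has
-    // lhs_contracting_dimension 1 and rhs_contracting_dimension 0, and both
-    // dimensions must have size k.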
-    CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension),
-             rhs->shape().dimensions(rhs_contracting_dimension))
-        << "lhs contracted dimension: "
-        << lhs->shape().dimensions(lhs_contracting_dimension)
-        << " rhs contracted dimension: "
-        << rhs->shape().dimensions(rhs_contracting_dimension);
-    const int64 contracted_dimension_size =
-        lhs->shape().dimensions(lhs_contracting_dimension);
-
-    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
-    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
-
-    auto result = Literal::CreateFromShape(dot->shape());
-
-    CHECK_EQ(dnums.lhs_batch_dimensions_size(),
-             dnums.rhs_batch_dimensions_size());
-
-    std::vector<int64> lhs_non_contracting_dims;
-    for (int64 i = 0; i < lhs_rank; i++) {
-      if (i != lhs_contracting_dimension) {
-        lhs_non_contracting_dims.push_back(i);
-      }
-    }
-
-    std::vector<int64> rhs_non_batch_non_contracting_dims;
-    FlatSet<int64> batch_dims_set(dnums.rhs_batch_dimensions().begin(),
-                                  dnums.rhs_batch_dimensions().end());
-    for (int64 i = 0; i < rhs_rank; i++) {
-      if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
-        rhs_non_batch_non_contracting_dims.push_back(i);
-      }
-    }
-
-    const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
-    const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();
-
-    DimensionVector lhs_index(lhs_rank);
-    DimensionVector rhs_index(rhs_rank);
-    TF_RETURN_IF_ERROR(
-        result->Populate<ReturnT>([&](ArraySlice<int64> result_index) {
-          ElementwiseT result_val = static_cast<ElementwiseT>(0);
-
-          // Find the corresponding non-contracting indices for lhs and rhs.
-          //
-          // For `result_index`, its batch dimension, if it exists, will be at
-          // the same dimension as the batch dimension of lhs and rhs. More
-          // specifically:
-          // - For lhs, the non-contracting dimensions, including the batch
-          //   dimension, have the same index as the `result_index`.
-          // - For rhs, the batch dimension is set separately from other
-          //   non-contracting dimensions, since these other non-contracting
-          //   dimensions in rhs follow the non-contracting dimensions of lhs
-          //   in the resulting index.
-          //
-          // As an example, for a resulting index:
-          //   result_index [result_batch, result_x, result_y]
-          // the effective lhs and rhs indices are:
-          //   lhs [result_batch, lhs_non_contracting_dim, contracting_dim]
-          //   rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
-          // `result_x` is only affected by the lhs_non_contracting_dim and
-          // likewise `result_y` only depends on rhs_non_contracting_dim.
-          //
-          // So we can look up the lhs and rhs indices by:
-          //
-          // lhs:
-          //   batch index is the same as `result_batch`.
-          //   non-contracting dimension is the same as
-          //   result_index[lhs_non_contracting_dim].
-          // rhs:
-          //   batch index: the same as `result_batch`.
-          //   non-contracting dimension index: *not* the same as
-          //   result_index[rhs_non_contracting_dim], since the
-          //   non-contracting dimensions of lhs are included in the
-          //   result_index first. Instead, the non_contracting_dim of rhs must
-          //   be calculated as follows:
-          //     lhs_non_contracting_dimensions_size +
-          //     (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
-          //
-          // Note that (rhs_non_batch_non_contracting_dim - batch_dim_size) is
-          // the index offset into the result_index that only depends on the
-          // non-batch and non-contracting dimensions of rhs. The -1 at the
-          // end translates size to index.
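-          //
-          // As a worked instance (illustrative, not from the original
-          // source): lhs [b, m, k] with contracting dimension 2, rhs
-          // [b, k, n] with batch dimension 0 and contracting dimension 1,
-          // result [b, m, n]. Then rhs_non_batch_non_contracting_dims = {2},
-          // batch_dim_size = 1, lhs_non_contracting_size = 2, and rhs
-          // dimension 2 reads result_index[2 + (2 - 1) - 1] = result_index[2],
-          // as expected.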
-          for (auto i : lhs_non_contracting_dims) {
-            lhs_index[i] = result_index[i];
-          }
-          for (auto i : dnums.rhs_batch_dimensions()) {
-            rhs_index[i] = result_index[i];
-          }
-          for (auto i : rhs_non_batch_non_contracting_dims) {
-            const int64 rhs_non_batch_non_contracting_dim =
-                lhs_non_contracting_size + (i - batch_dim_size) - 1;
-            rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
-          }
-
-          // Accumulates resulting product along the contracted dimension.
-          for (int64 i = 0; i < contracted_dimension_size; ++i) {
-            lhs_index[lhs_contracting_dimension] = i;
-            rhs_index[rhs_contracting_dimension] = i;
-
-            result_val +=
-                static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
-                static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
-          }
-
-          return static_cast<ReturnT>(result_val);
-        }));
-
-    parent_->evaluated_[dot] = std::move(result);
-    return Status::OK();
-  }
-
-  Status HandlePad(HloInstruction* pad) override {
-    CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
-    // Padding value must be scalar.
-    CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
-    CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()),
-             pad->padding_config().dimensions_size());
-
-    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
-                        ShapeInference::InferPadShape(
-                            /*operand_shape=*/pad->operand(0)->shape(),
-                            /*padding_value_shape=*/pad->operand(1)->shape(),
-                            /*padding_config=*/pad->padding_config()));
-    CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape))
-        << "return shape is set to: " << ShapeUtil::HumanString(pad->shape())
-        << "but is inferred to be: "
-        << ShapeUtil::HumanString(inferred_return_shape);
-
-    // Create new HLO of padded shape with padding value.
-    ReturnT scalar =
-        parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
-    auto result = Literal::CreateFromShape(pad->shape());
-    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
-        [&scalar](ArraySlice<int64> multi_index) { return scalar; }));
-
-    const Literal& evaluated_operand =
-        parent_->GetEvaluatedLiteralFor(pad->operand(0));
-
-    std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
-                                   0);
-    std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
-
-    // Loop through each element of the operand, assigning each to the
-    // corresponding index of the resulting padded literal.
-    const PaddingConfig& pad_config = pad->padding_config();
-
-    auto func = [&](ArraySlice<int64> input_index) {
-      for (auto i = 0; i < input_index.size(); ++i) {
-        // Interior padding occurs logically before edge padding, so in the
-        // case of negative edge padding, elements are removed from the
-        // interior-padded operand.
-        target_index[i] =
-            pad_config.dimensions(i).edge_padding_low() +
-            input_index[i] * (pad_config.dimensions(i).interior_padding() + 1);
-
-        // Account for negative low and high padding: skip assignment if any
-        // target index is out of range.
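-        // For example (a worked instance): edge_padding_low 1 and
-        // interior_padding 2 map input_index 3 to
-        //   target_index = 1 + 3 * (2 + 1) = 10;
-        // a negative edge_padding_high can then place 10 past the padded
-        // bound, in which case the element is simply dropped.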
- if (!(target_index[i] >= 0 && - target_index[i] < pad->shape().dimensions(i))) { - return true; - } - } - result->Set<ReturnT>(target_index, - evaluated_operand.Get<ReturnT>(input_index)); - return true; - }; - - std::vector<int64> zero_base(evaluated_operand.shape().dimensions_size(), - 0); - std::vector<int64> step(evaluated_operand.shape().dimensions_size(), 1); - - ShapeUtil::ForEachIndex( - evaluated_operand.shape(), zero_base, - AsInt64Slice(evaluated_operand.shape().dimensions()), step, func); - - parent_->evaluated_[pad] = std::move(result); - return Status::OK(); - } - - Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { - auto operand = dynamic_slice->operand(0); - auto start_indices = dynamic_slice->operand(1); - auto result_shape = dynamic_slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), - dynamic_slice->dynamic_slice_sizes())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - TF_RET_CHECK( - primitive_util::IsIntegralType(start_indices->shape().element_type())); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); - - switch (start_indices->shape().element_type()) { - case S32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice<int32>(operand_literal, start_indices_literal, - result_shape)); - } break; - case S64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice<int64>(operand_literal, start_indices_literal, - result_shape)); - } break; - case U32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice<uint32>(operand_literal, start_indices_literal, - result_shape)); - } break; - case U64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice<uint64>(operand_literal, start_indices_literal, - result_shape)); - } break; - default: - LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " - "start_indices: " - << PrimitiveType_Name(start_indices->shape().element_type()); - } - - return Status::OK(); - } - - Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override { - auto operand = dynamic_update_slice->operand(0); - auto update = dynamic_update_slice->operand(1); - auto start_indices = dynamic_update_slice->operand(2); - auto result_shape = dynamic_update_slice->shape(); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - TF_RET_CHECK( - primitive_util::IsIntegralType(start_indices->shape().element_type())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape())); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); - - switch (start_indices->shape().element_type()) { - case S32: { - 
TF_ASSIGN_OR_RETURN(
-            parent_->evaluated_[dynamic_update_slice],
-            DynamicUpdateSlice<int32>(operand_literal, update_literal,
-                                      start_indices_literal));
-      } break;
-      case S64: {
-        TF_ASSIGN_OR_RETURN(
-            parent_->evaluated_[dynamic_update_slice],
-            DynamicUpdateSlice<int64>(operand_literal, update_literal,
-                                      start_indices_literal));
-      } break;
-      case U32: {
-        TF_ASSIGN_OR_RETURN(
-            parent_->evaluated_[dynamic_update_slice],
-            DynamicUpdateSlice<uint32>(operand_literal, update_literal,
-                                       start_indices_literal));
-      } break;
-      case U64: {
-        TF_ASSIGN_OR_RETURN(
-            parent_->evaluated_[dynamic_update_slice],
-            DynamicUpdateSlice<uint64>(operand_literal, update_literal,
-                                       start_indices_literal));
-      } break;
-      default:
-        LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
-                      "start_indices: "
-                   << PrimitiveType_Name(start_indices->shape().element_type());
-    }
-
-    return Status::OK();
-  }
-
-  template <typename NativeT>
-  StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
-    auto operands = map->operands();
-    HloComputation* computation = map->to_apply();
-
-    auto result = Literal::CreateFromShape(map->shape());
-
-    HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
-    TF_RETURN_IF_ERROR(
-        result->Populate<ReturnT>([&](ArraySlice<int64> multi_index) {
-          std::vector<std::unique_ptr<Literal>> arg_literals;
-          arg_literals.reserve(operands.size());
-
-          // Construct scalar literal parameters to be passed to the map
-          // computation.
-          for (auto operand : operands) {
-            const Literal& arg_literal =
-                parent_->GetEvaluatedLiteralFor(operand);
-
-            auto curr_val = arg_literal.Get<NativeT>(multi_index);
-            auto curr_val_literal = Literal::CreateR0<NativeT>(curr_val);
-
-            arg_literals.push_back(std::move(curr_val_literal));
-          }
-
-          std::unique_ptr<Literal> computed_result =
-              embedded_evaluator
-                  .Evaluate<std::unique_ptr<Literal>>(*computation,
-                                                      arg_literals)
-                  .ConsumeValueOrDie();
-          // Clear visit states so that we can use the evaluator again on
-          // the same computation.
- embedded_evaluator.ResetVisitStates(); - - return computed_result->Get<ReturnT>({}); - })); - return std::move(result); - } - - Status HandleMap(HloInstruction* map) override { - switch (map->operand(0)->shape().element_type()) { - case PRED: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<bool>(map)); - break; - } - case U8: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map)); - break; - } - case U32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map)); - break; - } - case U64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint64>(map)); - break; - } - case S8: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map)); - break; - } - case S32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map)); - break; - } - case S64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int64>(map)); - break; - } - case F16: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], - MapImpl<Eigen::half>(map)); - break; - } - case F32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map)); - break; - } - case F64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<double>(map)); - break; - } - case C64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex64>(map)); - break; - } - default: - LOG(FATAL) << "HandleMap: unhandled primitive type for " - "input operand: " - << PrimitiveType_Name( - map->operand(0)->shape().element_type()); - } - - return Status::OK(); - } - - Status HandleReduce(HloInstruction* reduce) override { - auto arg = reduce->operand(0); - auto init_value = reduce->operand(1); - ArraySlice<int64> dimensions(reduce->dimensions()); - HloComputation* function = reduce->to_apply(); - TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == - ShapeUtil::Rank(arg->shape()) - dimensions.size()); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferReduceShape( - /*arg=*/arg->shape(), - /*init_value=*/init_value->shape(), - /*dimensions_to_reduce=*/dimensions, - /*to_apply=*/function->ComputeProgramShape())); - TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); - VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); - const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value); - VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get<ReturnT>({}); - - auto result = Literal::CreateFromShape(reduce->shape()); - - const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); - std::vector<int64> arg_dim_steps(arg_dimensions.size()); - std::vector<int64> arg_dim_counts(arg_dimensions.size()); - for (const int64 dim : dimensions) { - arg_dim_steps[dim] = 1; - arg_dim_counts[dim] = arg_dimensions[dim]; - } - - // Map each dimension in the result to a dimension in arg that isn't - // being reduced. - std::vector<int64> result_to_arg_index; - for (int64 i = 0; i < arg_dimensions.size(); ++i) { - if (arg_dim_steps[i] == 0) { - result_to_arg_index.push_back(i); - } - } - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - // For each resulting dimension, calculate and assign computed value. 
- TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](ArraySlice<int64> multi_index) { - ReturnT result_val = init_scalar; - - std::vector<int64> base(arg_dimensions.size()); - for (int64 i = 0; i < multi_index.size(); ++i) { - base[result_to_arg_index[i]] = multi_index[i]; - } - - // When the reduction is addition of floats, accumulate in a double - // for better precision. Also, avoid creating Literals for the - // intermediate results; it's much faster. - if (ShapeUtil::ElementIsFloating(init_literal.shape()) && - IsScalarAdd(function)) { - double computed_result = 0; - auto func = [&](ArraySlice<int64> input_index) { - computed_result += arg_literal.Get<float>(input_index); - return true; - }; - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return static_cast<ReturnT>(computed_result); - } - auto func = [&](ArraySlice<int64> input_index) { - auto curr_val = arg_literal.Get<ReturnT>(input_index); - - // Evaluate computation with specified literal operands. - auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val); - auto result_val_literal = Literal::CreateR0<ReturnT>(result_val); - std::vector<const Literal*> args = {result_val_literal.get(), - curr_val_literal.get()}; - - std::unique_ptr<Literal> computed_result = - embedded_evaluator.Evaluate<const Literal*>(*function, args) - .ConsumeValueOrDie(); - // Clear visit states so that we can use the evaluator again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - // Assign computed result to result_val. - result_val = computed_result->Get<ReturnT>({}); - return true; - }; - // Computes one element of the result, reducing all dimensions that - // contribute to that element. - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return result_val; - })); - - parent_->evaluated_[reduce] = std::move(result); - return Status::OK(); - } - - bool IsScalarAdd(HloComputation* computation) { - HloInstruction* instruction = computation->root_instruction(); - if (instruction->opcode() == HloOpcode::kAdd && - computation->num_parameters() == 2) { - const HloInstruction* lhs = instruction->operand(0); - const HloInstruction* rhs = instruction->operand(1); - return lhs->opcode() == HloOpcode::kParameter && - ShapeUtil::IsScalar(lhs->shape()) && - rhs->opcode() == HloOpcode::kParameter && - ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; - } - return false; - } - - Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { - auto operand = select_and_scatter->operand(0); - auto source = select_and_scatter->operand(1); - const Window& window = select_and_scatter->window(); - - const Literal& init_literal = - parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get<ReturnT>({}); - - auto result = Literal::CreateFromShape(select_and_scatter->shape()); - - // Initialize result array with the init value. 
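// Why HandleReduce's fast path above accumulates float additions in a
// double: a long float sum loses low-order bits once the accumulator grows.
// A standalone demonstration (ours, not commit code):
#include <cstdio>
int main() {
  float float_acc = 0.f;
  double double_acc = 0.;
  for (int i = 0; i < 10000000; ++i) {
    float_acc += 0.1f;   // float accumulator absorbs rounding error
    double_acc += 0.1f;  // same addends, double accumulator
  }
  // float_acc lands far from 1e6 (roughly 1.09e6 on typical hardware);
  // double_acc stays within rounding error of 1e6.
  std::printf("float: %f  double: %f\n", float_acc, double_acc);
  return 0;
}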
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>( - [&](ArraySlice<int64> output_index) { return init_scalar; })); - - std::vector<int64> window_dimension_sizes; - for (const auto& window_dimension : window.dimensions()) { - window_dimension_sizes.push_back(window_dimension.size()); - } - const Shape window_shape = ShapeUtil::MakeShape( - operand->shape().element_type(), window_dimension_sizes); - - HloComputation* select = select_and_scatter->select(); - HloComputation* scatter = select_and_scatter->scatter(); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); - - int64 rank = ShapeUtil::Rank(operand_literal.shape()); - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - DimensionVector source_index(rank); - - std::fill(source_index.begin(), source_index.end(), 0); - do { - // For each element in `source`, we place a window in `operand`. For each - // window placement, we iterate inside the window twice: - // - // 1. Find the selected index by applying `select` function to all - // elements. E.g., if the `select` function is GreaterEqual, the first - // iteration through the window finds the biggest value and returns its - // index. - // - // 2. Using the selected index, scatter value from `source` to result. We - // do this by iterating through the window, and comparing each index with - // the selected index. - optional<ReturnT> selected_val; - optional<std::vector<int64>> selected_index; - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), source_index, - [&](const std::vector<int64>& operand_index) { - auto curr_val = operand_literal.Get<ReturnT>(operand_index); - if (!selected_val) { - selected_val = curr_val; - selected_index = operand_index; - } - const auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val); - const auto selected_val_literal = - Literal::CreateR0<ReturnT>(*selected_val); - - const std::vector<const Literal*> args = { - selected_val_literal.get(), curr_val_literal.get()}; - std::unique_ptr<Literal> computed_result = - embedded_evaluator.Evaluate<const Literal*>(*select, args) - .ConsumeValueOrDie(); - bool selected = !computed_result->Get<bool>({}); - if (selected) { - selected_val = curr_val; - selected_index = operand_index; - } - embedded_evaluator.ResetVisitStates(); - }); - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), source_index, - [&](const std::vector<int64>& operand_index) { - if (std::equal(operand_index.begin(), operand_index.end(), - selected_index->begin())) { - auto source = source_literal.Get<ReturnT>(source_index); - auto scattered = result->Get<ReturnT>(operand_index); - const auto source_literal = Literal::CreateR0<ReturnT>(source); - const auto scattered_literal = - Literal::CreateR0<ReturnT>(scattered); - - const std::vector<const Literal*> args = { - source_literal.get(), scattered_literal.get()}; - std::unique_ptr<Literal> computed_result = - embedded_evaluator.Evaluate<const Literal*>(*scatter, args) - .ConsumeValueOrDie(); - result->Set(operand_index, computed_result->Get<ReturnT>({})); - // Clear visit states so that we can use the evaluator again - // on the same computation.
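// A worked instance of the two passes described above (our example, not
// commit code): operand {1, 5, 2, 7}, window size 2, stride 2, max-style
// select, add-style scatter, init 0, source {10, 20}. Pass 1 selects index 1
// in the first window and index 3 in the second; pass 2 scatters the source
// values there, giving {0, 10, 0, 20}. A 1-D sketch of that special case
// (stride assumed equal to window size):
#include <cstddef>
#include <vector>
std::vector<float> SelectAndScatter1D(const std::vector<float>& operand,
                                      const std::vector<float>& source,
                                      std::size_t window, float init) {
  std::vector<float> result(operand.size(), init);
  for (std::size_t s = 0; s < source.size(); ++s) {
    const std::size_t base = s * window;
    std::size_t selected = base;
    for (std::size_t i = base; i < base + window; ++i) {  // pass 1: select
      if (operand[i] > operand[selected]) selected = i;
    }
    result[selected] += source[s];  // pass 2: scatter to the selected index
  }
  return result;
}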
- embedded_evaluator.ResetVisitStates(); - } - }); - } while (IndexUtil::BumpIndices(source->shape(), &source_index)); - - parent_->evaluated_[select_and_scatter] = std::move(result); - return Status::OK(); - } - - Status HandleReduceWindow(HloInstruction* reduce_window) override { - auto operand = reduce_window->operand(0); - const Window& window = reduce_window->window(); - HloComputation* function = reduce_window->to_apply(); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferReduceWindowShape( - /*operand_shape=*/reduce_window->operand(0)->shape(), - /*init_value=*/reduce_window->operand(1)->shape(), window, - /*to_apply_shape=*/function->ComputeProgramShape())); - TF_RET_CHECK( - ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) - << "return shape is set to: " - << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) - << " but is inferred to be: " - << ShapeUtil::HumanStringWithLayout(inferred_return_shape); - - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(reduce_window->operand(0)); - VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString(); - const Literal& init_literal = - parent_->GetEvaluatedLiteralFor(reduce_window->operand(1)); - VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get<ReturnT>({}); - - auto result = Literal::CreateFromShape(reduce_window->shape()); - - // Creates a Shape object from window, for iteration below. - std::vector<int64> window_dimension_sizes; - for (const auto& window_dimension : window.dimensions()) { - window_dimension_sizes.push_back(window_dimension.size()); - } - const Shape window_shape = ShapeUtil::MakeShape( - operand->shape().element_type(), window_dimension_sizes); - - DimensionVector window_index(window.dimensions_size()); - DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](ArraySlice<int64> output_index) { - ReturnT result_val = init_scalar; - - std::fill(window_index.begin(), window_index.end(), 0); - std::fill(operand_index.begin(), operand_index.end(), 0); - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), output_index, - [&](const std::vector<int64>& operand_index) { - auto curr_val = operand_literal.Get<ReturnT>(operand_index); - - // Evaluate computation with specified literal operands. - const auto curr_val_literal = - Literal::CreateR0<ReturnT>(curr_val); - const auto result_val_literal = - Literal::CreateR0<ReturnT>(result_val); - const std::vector<const Literal*> args = { - result_val_literal.get(), curr_val_literal.get()}; - std::unique_ptr<Literal> computed_result = - embedded_evaluator.Evaluate<const Literal*>(*function, args) - .ConsumeValueOrDie(); - - // Clear visit states so that we can use the evaluator again - // on the same computation.
- embedded_evaluator.ResetVisitStates(); - - result_val = computed_result->Get<ReturnT>({}); - }); - - return result_val; - })); - - parent_->evaluated_[reduce_window] = std::move(result); - return Status::OK(); - } - - Status HandleSlice(HloInstruction* slice) override { - auto operand = slice->operand(0); - const Shape& shape = slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferSliceShape( - operand->shape(), slice->slice_starts(), - slice->slice_limits(), slice->slice_strides())); - TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const int64 rank = ShapeUtil::Rank(operand->shape()); - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto func = [&](ArraySlice<int64> out_index) { - DimensionVector operand_index(rank); - for (int64 i = 0; i < rank; ++i) { - operand_index[i] = - slice->slice_starts(i) + out_index[i] * slice->slice_strides(i); - } - return operand_literal.Get<ReturnT>(operand_index); - }; - - auto result = Literal::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func)); - parent_->evaluated_[slice] = std::move(result); - return Status::OK(); - } - - // Enable CLZ only for int32 and uint32. - template < - typename NativeT, - typename std::enable_if< - (std::is_floating_point<NativeT>::value || - std::is_integral<NativeT>::value || is_complex_t<NativeT>::value) && - !(std::is_same<NativeT, uint32>::value || - std::is_same<NativeT, int32>::value)>::type* = nullptr> - Status HandleClz(HloInstruction* clz) { - return InvalidArgument("Unsupported type for Clz"); - } - - template <typename NativeT, - typename std::enable_if< - std::is_same<NativeT, uint32>::value || - std::is_same<NativeT, int32>::value>::type* = nullptr> - Status HandleClz(HloInstruction* clz) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], - ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { - return 31 - tensorflow::Log2Floor(elem_operand); - })); - return Status::OK(); - } - - Status HandleClz(HloInstruction* clz) override { - return HandleClz<ElementwiseT>(clz); - } - - template <typename NativeT, typename std::enable_if<std::is_floating_point< - NativeT>::value>::type* = nullptr> - Status HandleSin(HloInstruction* sin) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], - ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { - return std::sin(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<std::is_integral<NativeT>::value || - is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleSin(HloInstruction* sin) { - return InvalidArgument("Unsupported type for Sin"); - } - - Status HandleSin(HloInstruction* sin) override { - return HandleSin<ElementwiseT>(sin); - } - - template <typename NativeT, typename std::enable_if<std::is_floating_point< - NativeT>::value>::type* = nullptr> - Status HandleCos(HloInstruction* cos) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], - ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { - return std::cos(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if<std::is_integral<NativeT>::value || - is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleCos(HloInstruction* cos) { - return InvalidArgument("Unsupported type for Cos"); - 
} - - Status HandleCos(HloInstruction* cos) override { - return HandleCos<ElementwiseT>(cos); - } - - template <typename NativeT, typename std::enable_if<std::is_same< - float, NativeT>::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[reduce_precision], - ElementWiseUnaryOp(reduce_precision, [reduce_precision]( - ElementwiseT elem) { - uint32_t value_as_int = tensorflow::bit_cast<uint32_t>(elem); - const uint32_t mantissa_bits = reduce_precision->mantissa_bits(); - const uint32_t exponent_bits = reduce_precision->exponent_bits(); - - // Code is based on the CPU/GPU implementation in LLVM-emitting code. - // - // Bits in float type: - // mantissa : bits [0:22] - // exponent : bits [23:30] - // sign : bits [31] - if (mantissa_bits < 23) { - const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); - - // Compute rounding bias for round-to-nearest with ties to even. - // This is equal to a base value of 0111... plus one bit if the last - // remaining mantissa bit is 1. - const uint32_t base_rounding_bias = - (last_mantissa_bit_mask >> 1) - 1; - const uint32_t x_last_mantissa_bit = - (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits); - const uint32_t x_rounding_bias = - x_last_mantissa_bit + base_rounding_bias; - - // Add rounding bias, and mask out truncated bits. Note that the - // case where adding the rounding bias overflows into the exponent - // bits is correct; the non-masked mantissa bits will all be zero, - // and the exponent will be incremented by one. - const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - value_as_int = value_as_int + x_rounding_bias; - value_as_int = value_as_int & truncation_mask; - } - if (exponent_bits < 8) { - // Masks for f32 values. - const uint32_t f32_sign_bit_mask = 1u << 31; - const uint32_t f32_exp_bits_mask = 0xffu << 23; - - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the - // most-significant bit -- is equal to 1.0f for all exponent sizes. - // Adding 2^(n-1)-1 to this gives us the highest non-infinite - // exponent for a bit-size of n, and subtracting 2^(n-1)-1 from - // this gives us the lowest exponent (corresponding to 0.0f). - // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n - // is (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponent_bits >= 1. - const uint32_t f32_exponent_bias = (1 << 7) - 1; - const uint32_t reduced_exponent_bias = - (1 << (exponent_bits - 1)) - 1; - const uint32_t reduced_max_exponent = - f32_exponent_bias + reduced_exponent_bias; - const uint32_t reduced_min_exponent = - f32_exponent_bias - reduced_exponent_bias; - - // Do we overflow or underflow? - const uint32_t x_exponent = value_as_int & f32_exp_bits_mask; - const bool x_overflows = x_exponent > (reduced_max_exponent << 23); - const bool x_underflows = - x_exponent <= (reduced_min_exponent << 23); - - // Compute appropriately-signed values of zero and infinity. - const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask; - const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask; - - // Force to zero or infinity if overflow or underflow. (Note that - // this truncates all denormal values to zero, rather than rounding - // them.) - value_as_int = x_overflows ? x_signed_inf : value_as_int; - value_as_int = x_underflows ?
x_signed_zero : value_as_int; - } - - float reduced_result = tensorflow::bit_cast<float>(value_as_int); - if (std::isnan(elem)) { - reduced_result = mantissa_bits > 0 - ? elem - : std::numeric_limits<float>::infinity(); - } - return reduced_result; - })); - return Status::OK(); - } - - template <typename NativeT, typename std::enable_if<std::is_same< - double, NativeT>::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Double not supported for reduce precision"); - } - - template < - typename NativeT, - typename std::enable_if<std::is_integral<NativeT>::value || - is_complex_t<NativeT>::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Unsupported type for reduce precision"); - } - - Status HandleReducePrecision(HloInstruction* reduce_precision) override { - return HandleReducePrecision<ElementwiseT>(reduce_precision); - } - - private: - template <typename IndexT> - StatusOr<std::unique_ptr<Literal>> DynamicSlice( - const Literal& operand_literal, const Literal& start_indices_literal, - const Shape& result_shape) { - auto start_indices_typed = start_indices_literal.data<IndexT>(); - std::vector<int64> start(start_indices_typed.begin(), - start_indices_typed.end()); - - std::vector<int64> operand_indices(start.size()); - - auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](ArraySlice<int64> multi_index) { - for (int64 i = 0; i < operand_indices.size(); ++i) { - CHECK_GE(multi_index[i] + start[i], 0); - // Mod is only used here to be consistent with the existing - // backends' behavior. - operand_indices[i] = (multi_index[i] + start[i]) % - operand_literal.shape().dimensions(i); - } - - auto result = operand_literal.Get<ReturnT>(operand_indices); - return result; - })); - - return std::move(result); - } - - template <typename IndexT> - StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice( - const Literal& operand_literal, const Literal& update_literal, - const Literal& start_indices_literal) { - auto result = operand_literal.CloneToUnique(); - auto start_indices_typed = start_indices_literal.data<IndexT>(); - const auto rank = ShapeUtil::Rank(result->shape()); - std::vector<int64> start(rank, 0); - for (int64 i = 0; i < rank; ++i) { - // All other implementations currently wrap-around the index, so this - // should do so as well. - start[i] = (start_indices_typed[i] % result->shape().dimensions(i)); - start[i] += (start[i] < 0) * result->shape().dimensions(i); - } - std::vector<int64> result_index(rank, 0); - - auto func = [&](ArraySlice<int64> update_index) { - std::transform(update_index.begin(), update_index.end(), start.begin(), - result_index.begin(), std::plus<int64>()); - // Same as above, wrap-around only to match other implementations' - // semantics. 
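// The dynamic-slice ops here fold out-of-range indices back into the valid
// range rather than clamping or erroring, to stay consistent with the other
// backends. The folding rule as a standalone helper (helper name is ours,
// not the commit's):
#include <cstdint>
inline int64_t WrapIndex(int64_t index, int64_t dim_size) {
  int64_t m = index % dim_size;     // C++ '%' can be negative for negative index
  return m < 0 ? m + dim_size : m;  // fold into [0, dim_size)
}
// E.g. a slice of size 2 starting at 3 in a dimension of size 4 reads
// indices WrapIndex(3, 4) == 3 and WrapIndex(4, 4) == 0.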
- std::transform(result_index.begin(), result_index.end(), - result->shape().dimensions().begin(), result_index.begin(), - std::modulus<int64>()); - result->Set<ReturnT>(result_index, - update_literal.Get<ReturnT>(update_index)); - return true; - }; - - std::vector<int64> base(update_literal.shape().dimensions_size(), 0); - std::vector<int64> step(update_literal.shape().dimensions_size(), 1); - ShapeUtil::ForEachIndex(update_literal.shape(), base, - AsInt64Slice(update_literal.shape().dimensions()), - step, func); - - return std::move(result); - } - - StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp( - HloInstruction* instruction, - const std::function<ElementwiseT(ElementwiseT)>& unary_op) { - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(instruction->operand(0)); - TF_ASSIGN_OR_RETURN( - auto result_literal, - (ElementWiseUnaryOpImpl<ReturnT, ReturnT>( - instruction, ConvertUnaryFunction(unary_op), operand_literal))); - - return std::move(result_literal); - } - - StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp( - HloInstruction* instruction, - const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>& - binary_op) { - const auto shape = instruction->shape(); - const auto* lhs = instruction->operand(0); - const auto* rhs = instruction->operand(1); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast - // is removed. - if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](ArraySlice<int64> multi_index) { - return ConvertBinaryFunction(binary_op)( - lhs_literal.Get<ReturnT>(multi_index), - rhs_literal.Get<ReturnT>(multi_index)); - })); - return std::move(result); - } - - template <typename LhsType, typename RhsType, typename EhsType> - StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp( - HloInstruction* instruction, - const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) { - const auto shape = instruction->shape(); - const auto* lhs = instruction->operand(0); - const auto* rhs = instruction->operand(1); - const auto* ehs = instruction->operand(2); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit - // broadcast is removed. 
- if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str(), - ShapeUtil::HumanString(ehs->shape()).c_str()); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR( - result->Populate<ReturnT>([&](ArraySlice<int64> multi_index) { - return ternary_op(lhs_literal.Get<LhsType>(multi_index), - rhs_literal.Get<RhsType>(multi_index), - ehs_literal.Get<EhsType>(multi_index)); - })); - - return std::move(result); - } - - template <typename NativeT> - static bool IsShiftOutOfBounds(NativeT rhs) { - typedef typename std::make_unsigned<NativeT>::type UnsignedT; - UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT; - UnsignedT rhs_unsigned = static_cast<UnsignedT>(rhs); - return rhs_unsigned >= lhs_size_unsigned; - } - - HloEvaluator* parent_; -}; // class HloEvaluator::TypedVisitor HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { - typed_visitors_[PRED] = MakeUnique<TypedVisitor<bool>>(this); - typed_visitors_[U8] = MakeUnique<TypedVisitor<uint8>>(this); + typed_visitors_[PRED] = MakeUnique<HloEvaluatorTypedVisitor<bool>>(this); + typed_visitors_[U8] = MakeUnique<HloEvaluatorTypedVisitor<uint8>>(this); typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVisitor: unhandled primitive type: U16."); + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "U16."); }); - typed_visitors_[U32] = MakeUnique<TypedVisitor<uint32>>(this); - typed_visitors_[U64] = MakeUnique<TypedVisitor<uint64>>(this); - typed_visitors_[S8] = MakeUnique<TypedVisitor<int8>>(this); + typed_visitors_[U32] = MakeUnique<HloEvaluatorTypedVisitor<uint32>>(this); + typed_visitors_[U64] = MakeUnique<HloEvaluatorTypedVisitor<uint64>>(this); + typed_visitors_[S8] = MakeUnique<HloEvaluatorTypedVisitor<int8>>(this); typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVisitor: unhandled primitive type: S16."); + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "S16."); }); - typed_visitors_[S32] = MakeUnique<TypedVisitor<int32>>(this); - typed_visitors_[S64] = MakeUnique<TypedVisitor<int64>>(this); - typed_visitors_[F16] = MakeUnique<TypedVisitor<Eigen::half, float>>(this); - typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this); - typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this); - typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this); + typed_visitors_[S32] = MakeUnique<HloEvaluatorTypedVisitor<int32>>(this); + typed_visitors_[S64] = MakeUnique<HloEvaluatorTypedVisitor<int64>>(this); + typed_visitors_[F16] = + MakeUnique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this); + typed_visitors_[F32] = MakeUnique<HloEvaluatorTypedVisitor<float>>(this); + typed_visitors_[F64] = MakeUnique<HloEvaluatorTypedVisitor<double>>(this); + typed_visitors_[C64] = 
MakeUnique<HloEvaluatorTypedVisitor<complex64>>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, std::tanh). To make evaluator work with BF16, we set all // elementwise computations to be done in F32 and do BF16<->F32 conversion // around the input and the output of the computations. - typed_visitors_[BF16] = MakeUnique<TypedVisitor<bfloat16, float>>(this); + typed_visitors_[BF16] = + MakeUnique<HloEvaluatorTypedVisitor<bfloat16, float>>(this); typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVistor: unhandled primitive type: TUPLE."); + "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); }); typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVisitor: unhandled primitive type: OPAQUE."); + "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); }); } @@ -3034,7 +977,7 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { // If predicate is of scalar type, no element-wise selection would be needed. // This would also handle output array of tuple types as the DefaultAction - // would go through the TypedVisitor which doesn't handle tuples. + // would go through the HloEvaluatorTypedVisitor which doesn't handle tuples. if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get<bool>({})) { evaluated_[select] = on_true.CloneToUnique(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index c0dcee0c3e..cc5676ea7b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -109,19 +109,16 @@ class HloEvaluator : public DfsHloVisitorWithDefault { substitutions); protected: - // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting - // literal type of each evaluated Handle* method of a TypedVisitor. - // There are however a few notable exceptions to this rule, notably: - // - HandleCompare and HandleIsFinite: where the resulting literal type is - // always boolean. - // These operations are handled outside of the parent HloEvaluator handlers - // instead of from within TypedVisitor. + // Make HloEvaluatorTypedVisitor a friend because it is logically part of this + // class. // - // Type params: - // - ReturnT: The type of input and output of each operation. - // - ElementwiseT: The type in which internal computation are done. - template <typename ReturnT, typename ElementwiseT = ReturnT> - class TypedVisitor; + // A straightforward implementation would be to make it a nested class + // declared and defined in hlo_evaluator.cc. Instead HloEvaluatorTypedVisitor + // lives as a separate class with its own header because its template gets + // instantiated many times and we want to use extern templates to shard out + // the compilation of those instantiations across multiple cc files. + template <typename ReturnT, typename ElementwiseT> + friend class HloEvaluatorTypedVisitor; // Wraps around instruction handling to infer types before dispatching to // the corresponding typed Visitor. 
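The sharding scheme described in the comment above follows the usual extern-template pattern; schematically, it looks like this (the file layout and the choice of types per shard are illustrative here, not taken from the commit):

  // In hlo_evaluator_typed_visitor.h: declare, but do not instantiate.
  extern template class HloEvaluatorTypedVisitor<float>;
  extern template class HloEvaluatorTypedVisitor<double>;

  // In one small .cc file per shard: the explicit instantiation, so each
  // instantiation is compiled exactly once rather than at every include site.
  template class HloEvaluatorTypedVisitor<float>;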
@@ -169,6 +166,33 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; private: + template <typename ReturnT, typename NativeT> + static StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl( + HloInstruction* instruction, + const std::function<ReturnT(NativeT)>& unary_op, + const Literal& operand_literal) { + const auto shape = instruction->shape(); + const auto* operand = instruction->operand(0); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!ShapeUtil::SameDimensions(shape, operand->shape())) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(operand->shape()).c_str()); + } + + auto result = Literal::CreateFromShape(shape); + + TF_RETURN_IF_ERROR(result->Populate<ReturnT>( + [&](tensorflow::gtl::ArraySlice<int64> multi_index) { + return unary_op(operand_literal.Get<NativeT>(multi_index)); + })); + return std::move(result); + } + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h new file mode 100644 index 0000000000..f1cb363478 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -0,0 +1,2102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// TODO(b/79274244): We'd like these type traits to live inside of +// HloEvaluatorTypedVisitor so they don't pollute namespace xla, but that +// crashes clang in the frontend. +// +// Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is +// a "private" header that's not exposed outside of hlo_evaluator.cc. +template <typename T> +using is_complex_t = std::is_same<T, complex64>; +template <typename T> +using is_complex64_t = std::is_same<T, complex64>; + +// Templated DfsHloVisitor for use by HloEvaluator. +// +// Typically ReturnT here indicates the resulting literal type of each evaluated +// Handle* method of a TypedVisitor. There are however a few notable exceptions +// to this rule, notably: +// - HandleCompare and HandleIsFinite: where the resulting literal type is +// always boolean. 
+// These operations are handled outside of the parent HloEvaluator handlers +// instead of from within TypedVisitor. +// +// Type params: +// - ReturnT: The type of input and output of each operation. +// - ElementwiseT: The type in which internal computations are done. +// +// This is logically a private part of HloEvaluator. It lives in this header +// file rather than in hlo_evaluator.cc because we use extern templates and a +// bunch of independent cc files to speed up compiling the many instantiations +// of this class. +template <typename ReturnT, typename ElementwiseT = ReturnT> +class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { + public: + explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {} + + // The following higher-order functions convert a function with ElementwiseT + // to a function with ReturnT. + std::function<ReturnT(ReturnT)> ConvertUnaryFunction( + const std::function<ElementwiseT(ElementwiseT)>& unary_op) { + return [&unary_op](ReturnT arg) { + return static_cast<ReturnT>(unary_op(static_cast<ElementwiseT>(arg))); + }; + } + std::function<ReturnT(ReturnT, ReturnT)> ConvertBinaryFunction( + const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>& + binary_op) { + return [&binary_op](ReturnT arg1, ReturnT arg2) { + return static_cast<ReturnT>(binary_op(static_cast<ElementwiseT>(arg1), + static_cast<ElementwiseT>(arg2))); + }; + } + std::function<ReturnT(ReturnT, ReturnT, ReturnT)> ConvertTernaryFunction( + const std::function<ElementwiseT(ElementwiseT, ElementwiseT, + ElementwiseT)>& ternary_op) { + return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { + return static_cast<ReturnT>(ternary_op(static_cast<ElementwiseT>(arg1), + static_cast<ElementwiseT>(arg2), + static_cast<ElementwiseT>(arg3))); + }; + } + + Status DefaultAction(HloInstruction* hlo_instruction) override { + return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", + HloOpcodeString(hlo_instruction->opcode()).c_str()); + } + + // TODO(b/35950897): many of the stl functions used in the handlers are not + // overloaded for every XLA primitive type. + + template <typename NativeT, + typename std::enable_if<std::is_unsigned<NativeT>::value>::type* = + nullptr> + Status HandleAbs(HloInstruction* abs) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return elem_operand; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<std::is_signed<NativeT>::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return std::abs(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<is_complex64_t<NativeT>::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(abs->operand(0)); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[abs], + (HloEvaluator::ElementWiseUnaryOpImpl<float, NativeT>( + abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, + operand_literal))); + + return Status::OK(); + } + + Status HandleAbs(HloInstruction* abs) override { + // If the operand is of C64 type, the return type of abs will be F32. + // However, ElementwiseT would default to the return type F32, and thus + // ElementwiseT must be specified explicitly as C64 below.
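// The Convert*Function adapters above are what let a visitor store one type
// but compute in a wider one; the constructor registers
// HloEvaluatorTypedVisitor<Eigen::half, float> and <bfloat16, float> for
// exactly this reason. A minimal sketch of the widen-compute-narrow shape of
// those adapters (ours, not commit code):
#include <cmath>
template <typename ReturnT, typename ElementwiseT = ReturnT>
ReturnT ExpViaWiderType(ReturnT arg) {
  // Mirrors ConvertUnaryFunction: cast in, compute, cast back out.
  return static_cast<ReturnT>(std::exp(static_cast<ElementwiseT>(arg)));
}
// E.g. ExpViaWiderType<float, double>(1.0f) computes std::exp in double and
// narrows the result back to float.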
+ if (abs->operand(0)->shape().element_type() == C64) { + return HandleAbs<complex64>(abs); + } + return HandleAbs<ElementwiseT>(abs); + } + + template < + typename NativeT, + typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[round], + ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { + return std::round(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { + return InvalidArgument("Unsupported type for Round"); + } + + Status HandleRound(HloInstruction* round) override { + return HandleRound<ReturnT>(round); + } + + Status HandleBroadcast(HloInstruction* broadcast) override { + parent_->evaluated_[broadcast] = + Literal::CreateFromShape(broadcast->shape()); + auto output = parent_->evaluated_[broadcast].get(); + const Literal& operand_to_broadcast = + parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); + std::vector<int64> broadcast_indices( + ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); + + TF_RET_CHECK(broadcast->dimensions().size() == + ShapeUtil::Rank(operand_to_broadcast.shape())) + << "broadcast dimensions is of size: " << broadcast->dimensions().size() + << " and rank of operand_to_broadcast is: " + << ShapeUtil::Rank(operand_to_broadcast.shape()); + // Checks that operand's dimensions are the same as the broadcast's + // dimensions along the dimensions to be broadcasted. + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == + operand_to_broadcast.shape().dimensions(i)); + } + + return output->Populate<ReturnT>( + [&](tensorflow::gtl::ArraySlice<int64> multi_index) { + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; + } + return operand_to_broadcast.Get<ReturnT>(broadcast_indices); + }); + } + + template < + typename NativeT, + typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], + ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { + return std::ceil(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { + return InvalidArgument("Unsupported type for Ceil"); + } + + Status HandleCeil(HloInstruction* ceil) override { + return HandleCeil<ReturnT>(ceil); + } + + Status HandleConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result, + parent_->GetEvaluatedLiteralFor(operand).Convert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); + } + return Status::OK(); + } + + Status HandleBitcastConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> 
result, + parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); + } + return Status::OK(); + } + + Status HandleExp(HloInstruction* exp) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], + ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { + return std::exp(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[floor], + ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { + return std::floor(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Floor"); + } + + Status HandleFloor(HloInstruction* floor) override { + return HandleFloor<ReturnT>(floor); + } + + Status HandleLog(HloInstruction* log) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], + ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { + return std::log(elem_operand); + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if< + std::is_integral<NativeT>::value && + !std::is_same<NativeT, bool>::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return ~elem_operand; + })); + return Status::OK(); + } + + template <typename NativeT, typename std::enable_if<std::is_floating_point< + NativeT>::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<std::is_same<NativeT, bool>::value>::type* = + nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + return InvalidArgument("Unsupported type for Not"); + } + + Status HandleNot(HloInstruction* not_) override { + return HandleNot<ElementwiseT>(not_); + } + + template <typename NativeT, + typename std::enable_if< + std::is_signed<NativeT>::value && + !std::is_floating_point<NativeT>::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + using type = typename std::make_unsigned<NativeT>::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { + return NativeT(-type(elem_operand)); + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if< + !std::is_signed<NativeT>::value || + std::is_floating_point<NativeT>::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp( + negate, [](ElementwiseT elem_operand) { 
return -elem_operand; })); + return Status::OK(); + } + + Status HandleNegate(HloInstruction* negate) override { + return HandleNegate<ReturnT>(negate); + } + + template < + typename NativeT, + typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + return (ElementwiseT(0) < elem_operand) - + (elem_operand < ElementwiseT(0)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + auto abs_val = std::abs(elem_operand); + return 0 == abs_val ? ElementwiseT(0) + : elem_operand / abs_val; + })); + return Status::OK(); + } + + Status HandleSign(HloInstruction* sign) override { + return HandleSign<ReturnT>(sign); + } + + template <typename NativeT, typename std::enable_if<std::is_floating_point< + NativeT>::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], + ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return std::atan2(lhs_elem, rhs_elem); + })); + return Status::OK(); + } + + template <typename NativeT, typename std::enable_if<!std::is_floating_point< + NativeT>::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + return InvalidArgument("Unsupported type for Atan2"); + } + + Status HandleAtan2(HloInstruction* atan2) override { + return HandleAtan2<ElementwiseT>(atan2); + } + + Status HandleTanh(HloInstruction* tanh) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], + ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { + return std::tanh(elem_operand); + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if< + std::is_signed<NativeT>::value && + !std::is_floating_point<NativeT>::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + using type = typename std::make_unsigned<NativeT>::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return NativeT(type(lhs_elem) * type(rhs_elem)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<std::is_unsigned<NativeT>::value || + std::is_floating_point<NativeT>::value || + is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem * rhs_elem; + })); + return Status::OK(); + } + + Status HandleMultiply(HloInstruction* multiply) override { + return HandleMultiply<ElementwiseT>(multiply); + } + + Status HandleSubtract(HloInstruction* subtract) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[subtract], + ElementWiseBinaryOp(subtract, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem - rhs_elem; + })); + return Status::OK(); + } + + Status HandleAdd(HloInstruction* add) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], + ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem + rhs_elem; + })); + return Status::OK(); + } + + Status 
HandleDivide(HloInstruction* divide) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem / rhs_elem; + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<std::is_integral<NativeT>::value>::type* = + nullptr> + Status HandleMaximum(HloInstruction* maximum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { + return std::max(lhs, rhs); + })); + return Status::OK(); + } + + template <typename NativeT, typename std::enable_if<std::is_floating_point< + NativeT>::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { + return ((lhs >= rhs) || std::isnan(lhs)) ? lhs : rhs; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + return InvalidArgument("Unsupported type for Maximum"); + } + + Status HandleMaximum(HloInstruction* maximum) override { + return HandleMaximum<ElementwiseT>(maximum); + } + + template <typename NativeT, + typename std::enable_if<std::is_integral<NativeT>::value>::type* = + nullptr> + Status HandleMinimum(HloInstruction* minimum) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::min(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template <typename NativeT, typename std::enable_if<std::is_floating_point< + NativeT>::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? 
lhs_el : rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + return InvalidArgument("Unsupported type for Minimum"); + } + + Status HandleMinimum(HloInstruction* minimum) override { + return HandleMinimum<ElementwiseT>(minimum); + } + + Status HandlePower(HloInstruction* power) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], + ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::pow(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::fmod(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + return InvalidArgument("Unsupported type for Remainder"); + } + + Status HandleRemainder(HloInstruction* remainder) override { + return HandleRemainder<ElementwiseT>(remainder); + } + + template <typename NativeT, + typename std::enable_if<std::is_integral<NativeT>::value>::type* = + nullptr> + Status HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el & rhs_el; + })); + return Status::OK(); + } + + template <typename NativeT, typename std::enable_if<std::is_floating_point< + NativeT>::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el && rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + return InvalidArgument("Unsupported type for And"); + } + + Status HandleAnd(HloInstruction* and_) override { + return HandleAnd<ElementwiseT>(and_); + } + + template <typename NativeT, + typename std::enable_if<std::is_integral<NativeT>::value>::type* = + nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el | rhs_el; + })); + return Status::OK(); + } + + template <typename NativeT, typename std::enable_if<std::is_floating_point< + NativeT>::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el || rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + return InvalidArgument("Unsupported type for Or"); + } + + Status HandleOr(HloInstruction* or_) override { + return HandleOr<ElementwiseT>(or_); + } + + template <typename NativeT, + typename std::enable_if< + std::is_integral<NativeT>::value && + !std::is_same<NativeT, bool>::value>::type* = nullptr> + Status HandleShiftLeft(HloInstruction* shl) { + 
TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shl], + ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { + return IsShiftOutOfBounds<NativeT>(rhs_elem) ? 0 + : (lhs_elem << rhs_elem); + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<!std::is_integral<NativeT>::value || + std::is_same<NativeT, bool>::value>::type* = + nullptr> + Status HandleShiftLeft(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftLeft"); + } + + Status HandleShiftLeft(HloInstruction* shl) override { + return HandleShiftLeft<ElementwiseT>(shl); + } + template <typename NativeT, + typename std::enable_if< + std::is_integral<NativeT>::value && + !std::is_same<NativeT, bool>::value>::type* = nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr) { + typedef typename std::make_signed<NativeT>::type SignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + SignedT lhs_signed = static_cast<SignedT>(lhs_elem); + if (IsShiftOutOfBounds<NativeT>(rhs_elem)) { + return lhs_signed < 0 ? static_cast<SignedT>(-1) : 0; + } else { + return lhs_signed >> rhs_elem; + } + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<!std::is_integral<NativeT>::value || + std::is_same<NativeT, bool>::value>::type* = + nullptr> + Status HandleShiftRightArithmetic(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + } + + Status HandleShiftRightArithmetic(HloInstruction* shra) override { + return HandleShiftRightArithmetic<ElementwiseT>(shra); + } + + template <typename NativeT, + typename std::enable_if< + std::is_integral<NativeT>::value && + !std::is_same<NativeT, bool>::value>::type* = nullptr> + Status HandleShiftRightLogical(HloInstruction* shr) { + typedef typename std::make_unsigned<NativeT>::type UnsignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + // If shift amount is greater than the number of bits, then return 0. 
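// Shifting by at least the operand's bit width is undefined behavior in
// C++, which is why every shift handler here tests the amount first: shl
// and shrl pin the result to 0, shra pins it to the sign fill (-1 or 0).
// A standalone restatement of the guard for 32-bit operands, mirroring
// IsShiftOutOfBounds (helper name is ours):
#include <climits>
#include <cstdint>
inline bool ShiftAmountOutOfBounds32(int32_t rhs) {
  // Negative amounts become huge after the unsigned cast, so the same
  // comparison catches them as well.
  return static_cast<uint32_t>(rhs) >= sizeof(int32_t) * CHAR_BIT;
}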
+ if (IsShiftOutOfBounds<NativeT>(rhs_elem)) { + return static_cast<NativeT>(0); + } + return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<!std::is_integral<NativeT>::value || + std::is_same<NativeT, bool>::value>::type* = + nullptr> + Status HandleShiftRightLogical(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftRightLogical"); + } + + Status HandleShiftRightLogical(HloInstruction* shrl) override { + return HandleShiftRightLogical<ElementwiseT>(shrl); + } + + template < + typename NativeT, + typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleClamp(HloInstruction* clamp) { + std::function<ElementwiseT(ElementwiseT, ElementwiseT, ElementwiseT)> + clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { + return std::fmin(high, std::fmax(value, low)); + }; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[clamp], + ElementwiseTernaryOp(clamp, + std::move(ConvertTernaryFunction(clamp_op)))); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleClamp(HloInstruction*) { + return InvalidArgument("Unsupported type for Clamp"); + } + + Status HandleClamp(HloInstruction* clamp) override { + return HandleClamp<ElementwiseT>(clamp); + } + + Status HandleSelect(HloInstruction* select) override { + CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); + CHECK(!ShapeUtil::IsTuple(select->shape())); + std::function<ReturnT(bool, ReturnT, ReturnT)> select_op = + [](bool pred, ReturnT on_true, ReturnT on_false) { + if (pred) { + return on_true; + } + return on_false; + }; + TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], + ElementwiseTernaryOp(select, std::move(select_op))); + return Status::OK(); + } + + Status HandleReverse(HloInstruction* reverse) override { + const auto result_shape = reverse->shape(); + const auto reverse_dimensions = reverse->dimensions(); + + auto operand = reverse->operand(0); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferReverseShape(operand->shape(), + reverse_dimensions)); + + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(result_shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + auto result = Literal::CreateFromShape(result_shape); + + TF_RETURN_IF_ERROR(result->Populate<ReturnT>( + [&](tensorflow::gtl::ArraySlice<int64> out_index) { + std::vector<int64> from_index(out_index.begin(), out_index.end()); + for (const int64 dim : reverse_dimensions) { + from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; + } + return operand_literal.Get<ReturnT>(from_index); + })); + + parent_->evaluated_[reverse] = std::move(result); + return Status::OK(); + } + + Status HandleConvolution(HloInstruction* conv) override { + auto lhs = conv->operand(0); + auto rhs = conv->operand(1); + const auto& window = conv->window(); + const Shape& result_shape = conv->shape(); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + + TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); + TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); + CHECK(ShapeUtil::IsArray(lhs_shape)); + CHECK(ShapeUtil::IsArray(rhs_shape)); + 
CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); + CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); + + const auto& dnums = conv->convolution_dimension_numbers(); + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); + CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); + CHECK_GE(num_spatial_dims, 0); + CHECK_EQ(window.dimensions_size(), num_spatial_dims); + + const auto lhs_rank = ShapeUtil::Rank(lhs_shape); + const auto rhs_rank = ShapeUtil::Rank(rhs_shape); + + CHECK_EQ(num_spatial_dims + 2, lhs_rank); + CHECK_EQ(num_spatial_dims + 2, rhs_rank); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, + window, dnums)); + CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(result_shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + std::vector<int64> window_dimension_sizes; + for (auto i : dnums.kernel_spatial_dimensions()) { + window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); + } + + const Shape& window_shape = + ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes); + + DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape); + DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape); + + auto lhs_literal_data = lhs_literal.data<ReturnT>(); + auto rhs_literal_data = rhs_literal.data<ReturnT>(); + + auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, + &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, + rhs_literal_data]( + tensorflow::gtl::ArraySlice<int64> out_index) { + // Dimension number applicable for input (lhs). + const int64 input_batch_dim = dnums.input_batch_dimension(); + const int64 input_z_dim = dnums.input_feature_dimension(); + // Dimension number applicable for kernel (rhs). + const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); + const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); + // Dimension number applicable for output. + const int64 output_batch_dim = dnums.output_batch_dimension(); + const int64 output_z_dim = dnums.output_feature_dimension(); + + const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + + ElementwiseT result_val = static_cast<ElementwiseT>(0); + DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), + 0); + + // Convolve input feature with kernel. + do { + for (int64 iz = 0; iz < z_size; ++iz) { + int64 lhs_linear_index = 0; + lhs_linear_index += out_index[output_batch_dim] * + lhs_dim_multipliers[input_batch_dim]; + lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; + + int64 rhs_linear_index = 0; + rhs_linear_index += out_index[output_z_dim] * + rhs_dim_multipliers[kernel_output_z_dim]; + rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; + + // Find corresponding spatial dimension index for input (lhs). + for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { + // Spatial dimension number for input (lhs) and output. 
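+            // A worked example of the arithmetic below (illustrative values
+            // only): with stride = 2, padding_low = 1, window_dilation = 1
+            // and base_dilation = 1, an output spatial index of 3 and a
+            // window offset rhs_spatial_index[ki] = 2 give
+            //   undilated_index = 3 * 2 - 1 + 2 * 1 = 7,
+            // which is also the final lhs spatial index, since there is no
+            // base dilation to skip or divide out.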
+ const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); + const int64 output_spatial_dim = + dnums.output_spatial_dimensions(ki); + + // Calculate lhs (input) index without taking base dilation into + // account. + const auto& window_dim = window.dimensions(ki); + const int64 undilated_index = + out_index[output_spatial_dim] * window_dim.stride() - + window_dim.padding_low() + + rhs_spatial_index[ki] * window_dim.window_dilation(); + // Skip if the lhs (input) index is to be dilated. As an + // optimization, skip this mod if there's no dilation. + if (window_dim.base_dilation() > 1 && + undilated_index % window_dim.base_dilation() != 0) { + goto cnt; + } + + // Calculate the actual lhs (input) index after dilation. As an + // optimization, skip this integer divide if there's no dilation. + int64 lhs_spatial_index; + if (window_dim.base_dilation() > 1) { + lhs_spatial_index = undilated_index / window_dim.base_dilation(); + } else { + lhs_spatial_index = undilated_index; + } + lhs_linear_index += + lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; + + // Skip if input index is not in bounds. + if (!(lhs_spatial_index >= 0 && + lhs_spatial_index < + lhs_shape.dimensions(input_spatial_dim))) { + goto cnt; + } + + rhs_linear_index += + (window_dim.window_reversal() + ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]) * + rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; + } + + result_val += + static_cast<ElementwiseT>(lhs_literal_data[lhs_linear_index]) * + static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]); + } + cnt : {} + } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); + + return static_cast<ReturnT>(result_val); + }; + + auto result = Literal::CreateFromShape(result_shape); + TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func)); + + parent_->evaluated_[conv] = std::move(result); + return Status::OK(); + } + + Status HandleDot(HloInstruction* dot) override { + auto lhs = dot->operand(0); + auto rhs = dot->operand(1); + CHECK(ShapeUtil::IsArray(dot->shape())); + CHECK(ShapeUtil::IsArray(lhs->shape())); + CHECK(ShapeUtil::IsArray(rhs->shape())); + + const auto& dnums = dot->dot_dimension_numbers(); + + const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); + const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); + + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); + + // There must be 1 and only 1 Contracting dimension for lhs and rhs. + CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); + const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + // Contracted dimension sizes must be the same. 
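+    // For example, a plain matrix multiply of an lhs of shape [2,3] with an
+    // rhs of shape [3,4] contracts lhs dimension 1 against rhs dimension 0;
+    // both must have size 3, which is what the check below enforces.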
+    CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension),
+             rhs->shape().dimensions(rhs_contracting_dimension))
+        << "lhs contracted dimension: "
+        << lhs->shape().dimensions(lhs_contracting_dimension)
+        << " rhs contracted dimension: "
+        << rhs->shape().dimensions(rhs_contracting_dimension);
+    const int64 contracted_dimension_size =
+        lhs->shape().dimensions(lhs_contracting_dimension);
+
+    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
+    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
+
+    auto result = Literal::CreateFromShape(dot->shape());
+
+    CHECK_EQ(dnums.lhs_batch_dimensions_size(),
+             dnums.rhs_batch_dimensions_size());
+
+    std::vector<int64> lhs_non_contracting_dims;
+    for (int64 i = 0; i < lhs_rank; i++) {
+      if (i != lhs_contracting_dimension) {
+        lhs_non_contracting_dims.push_back(i);
+      }
+    }
+
+    std::vector<int64> rhs_non_batch_non_contracting_dims;
+    tensorflow::gtl::FlatSet<int64> batch_dims_set(
+        dnums.rhs_batch_dimensions().begin(),
+        dnums.rhs_batch_dimensions().end());
+    for (int64 i = 0; i < rhs_rank; i++) {
+      if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
+        rhs_non_batch_non_contracting_dims.push_back(i);
+      }
+    }
+
+    const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
+    const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();
+
+    DimensionVector lhs_index(lhs_rank);
+    DimensionVector rhs_index(rhs_rank);
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&](tensorflow::gtl::ArraySlice<int64> result_index) {
+          ElementwiseT result_val = static_cast<ElementwiseT>(0);
+
+          // Find the corresponding non-contracting indices for lhs and rhs.
+          //
+          // For `result_index`, its batch dimension, if it exists, will be
+          // at the same dimension as the batch dimension of lhs and rhs.
+          // More specifically:
+          // - For lhs, the non-contracting dimensions, including the batch
+          // dimension, have the same index as the `result_index`.
+          // - For rhs, the batch dimension is set separately from the other
+          // non-contracting dimensions, since these other non-contracting
+          // dimensions in rhs follow the non-contracting dimensions of lhs
+          // in the resulting index.
+          //
+          // As an example, for a resulting index:
+          //  result_index [result_batch, result_x, result_y]
+          // the corresponding lhs and rhs indices are:
+          //  lhs [result_batch, lhs_non_contracting_dim, contracting_dim]
+          //  rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
+          // `result_x` is only affected by the lhs_non_contracting_dim and
+          // likewise `result_y` only depends on the rhs_non_contracting_dim.
+          //
+          // so we can look up the lhs and rhs indices by:
+          //
+          // lhs:
+          //  batch index is the same as `result_batch`.
+          //  non-contracting dimension is the same as
+          //  result_index[lhs_non_contracting_dim]
+          // rhs:
+          //  batch index: the same as `result_batch`.
+          //  non-contracting dimension index: *not* the same as
+          //  result_index[rhs_non_contracting_dim], since the
+          //  non-contracting dimensions of lhs are included in the
+          //  result_index first. Instead, the non_contracting_dim of rhs
+          //  must be calculated as follows:
+          //    lhs_non_contracting_dimensions_size +
+          //    (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
+          //
+          // Note that (rhs_non_batch_non_contracting_dim - batch_dim_size)
+          // is the index offset to the result_index that only depends on
+          // the non_batch and non-contracting dimensions of rhs. -1 at the
+          // end translates size to index.
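+          // As a concrete (illustrative) instance of the mapping above: for
+          // lhs [B, M, K] and rhs [B, K, N] with batch dimension 0 and
+          // contracting dimensions 2 (lhs) and 1 (rhs), a result index
+          // [b, m, n] reads lhs at [b, m, k] and rhs at [b, k, n], with k
+          // supplied by the accumulation loop below.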
+          for (auto i : lhs_non_contracting_dims) {
+            lhs_index[i] = result_index[i];
+          }
+          for (auto i : dnums.rhs_batch_dimensions()) {
+            rhs_index[i] = result_index[i];
+          }
+          for (auto i : rhs_non_batch_non_contracting_dims) {
+            const int64 rhs_non_batch_non_contracting_dim =
+                lhs_non_contracting_size + (i - batch_dim_size) - 1;
+            rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
+          }
+
+          // Accumulates resulting product along the contracted dimension.
+          for (int64 i = 0; i < contracted_dimension_size; ++i) {
+            lhs_index[lhs_contracting_dimension] = i;
+            rhs_index[rhs_contracting_dimension] = i;
+
+            result_val +=
+                static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
+                static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
+          }
+
+          return static_cast<ReturnT>(result_val);
+        }));
+
+    parent_->evaluated_[dot] = std::move(result);
+    return Status::OK();
+  }
+
+  Status HandlePad(HloInstruction* pad) override {
+    CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
+    // Padding value must be scalar.
+    CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
+    CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()),
+             pad->padding_config().dimensions_size());
+
+    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+                        ShapeInference::InferPadShape(
+                            /*operand_shape=*/pad->operand(0)->shape(),
+                            /*padding_value_shape=*/pad->operand(1)->shape(),
+                            /*padding_config=*/pad->padding_config()));
+    CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape))
+        << "return shape is set to: " << ShapeUtil::HumanString(pad->shape())
+        << " but is inferred to be: "
+        << ShapeUtil::HumanString(inferred_return_shape);
+
+    // Create a new literal of the padded shape, filled with the padding
+    // value.
+    ReturnT scalar =
+        parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
+    auto result = Literal::CreateFromShape(pad->shape());
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&scalar](tensorflow::gtl::ArraySlice<int64> multi_index) {
+          return scalar;
+        }));
+
+    const Literal& evaluated_operand =
+        parent_->GetEvaluatedLiteralFor(pad->operand(0));
+
+    std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
+                                   0);
+    std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
+
+    // Loop through each element of the operand, assigning each to the
+    // corresponding index of the resulting padded literal.
+    const PaddingConfig& pad_config = pad->padding_config();
+
+    auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+      for (auto i = 0; i < input_index.size(); ++i) {
+        // Interior padding occurs logically before edge padding, so in the
+        // case of negative edge padding elements are removed from the
+        // interior-padded operand.
+        target_index[i] =
+            pad_config.dimensions(i).edge_padding_low() +
+            input_index[i] * (pad_config.dimensions(i).interior_padding() + 1);
+
+        // Account for negative low and high padding: skip assignment if
+        // any target index is out of range.
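+        // For instance, with edge_padding_low = -1 and interior_padding = 1,
+        // input index 0 maps to target index -1 and is skipped by the bounds
+        // check below, while input index 2 maps to -1 + 2 * (1 + 1) = 3.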
+        if (!(target_index[i] >= 0 &&
+              target_index[i] < pad->shape().dimensions(i))) {
+          return true;
+        }
+      }
+      result->Set<ReturnT>(target_index,
+                           evaluated_operand.Get<ReturnT>(input_index));
+      return true;
+    };
+
+    std::vector<int64> zero_base(evaluated_operand.shape().dimensions_size(),
+                                 0);
+    std::vector<int64> step(evaluated_operand.shape().dimensions_size(), 1);
+
+    ShapeUtil::ForEachIndex(
+        evaluated_operand.shape(), zero_base,
+        AsInt64Slice(evaluated_operand.shape().dimensions()), step, func);
+
+    parent_->evaluated_[pad] = std::move(result);
+    return Status::OK();
+  }
+
+  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override {
+    auto operand = dynamic_slice->operand(0);
+    auto start_indices = dynamic_slice->operand(1);
+    auto result_shape = dynamic_slice->shape();
+    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+                        ShapeInference::InferDynamicSliceShape(
+                            operand->shape(), start_indices->shape(),
+                            dynamic_slice->dynamic_slice_sizes()));
+    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+        << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
+        << " but is inferred to be: "
+        << ShapeUtil::HumanString(inferred_return_shape);
+    TF_RET_CHECK(
+        primitive_util::IsIntegralType(start_indices->shape().element_type()));
+
+    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+    const Literal& start_indices_literal =
+        parent_->GetEvaluatedLiteralFor(start_indices);
+
+    switch (start_indices->shape().element_type()) {
+      case S32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_slice],
+            DynamicSlice<int32>(operand_literal, start_indices_literal,
+                                result_shape));
+      } break;
+      case S64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_slice],
+            DynamicSlice<int64>(operand_literal, start_indices_literal,
+                                result_shape));
+      } break;
+      case U32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_slice],
+            DynamicSlice<uint32>(operand_literal, start_indices_literal,
+                                 result_shape));
+      } break;
+      case U64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_slice],
+            DynamicSlice<uint64>(operand_literal, start_indices_literal,
+                                 result_shape));
+      } break;
+      default:
+        LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for "
+                      "start_indices: "
+                   << PrimitiveType_Name(start_indices->shape().element_type());
+    }
+
+    return Status::OK();
+  }
+
+  Status HandleDynamicUpdateSlice(
+      HloInstruction* dynamic_update_slice) override {
+    auto operand = dynamic_update_slice->operand(0);
+    auto update = dynamic_update_slice->operand(1);
+    auto start_indices = dynamic_update_slice->operand(2);
+    auto result_shape = dynamic_update_slice->shape();
+    TF_ASSIGN_OR_RETURN(
+        auto inferred_return_shape,
+        ShapeInference::InferDynamicUpdateSliceShape(
+            operand->shape(), update->shape(), start_indices->shape()));
+    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+        << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
+        << " but is inferred to be: "
+        << ShapeUtil::HumanString(inferred_return_shape);
+    TF_RET_CHECK(
+        primitive_util::IsIntegralType(start_indices->shape().element_type()));
+    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape()));
+
+    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+    const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update);
+    const Literal& start_indices_literal =
+        parent_->GetEvaluatedLiteralFor(start_indices);
+
+    switch (start_indices->shape().element_type()) {
+      case S32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<int32>(operand_literal, update_literal,
+                                      start_indices_literal));
+      } break;
+      case S64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<int64>(operand_literal, update_literal,
+                                      start_indices_literal));
+      } break;
+      case U32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<uint32>(operand_literal, update_literal,
+                                       start_indices_literal));
+      } break;
+      case U64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<uint64>(operand_literal, update_literal,
+                                       start_indices_literal));
+      } break;
+      default:
+        LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
+                      "start_indices: "
+                   << PrimitiveType_Name(start_indices->shape().element_type());
+    }
+
+    return Status::OK();
+  }
+
+  template <typename NativeT>
+  StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
+    auto operands = map->operands();
+    HloComputation* computation = map->to_apply();
+
+    auto result = Literal::CreateFromShape(map->shape());
+
+    HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+          std::vector<std::unique_ptr<Literal>> arg_literals;
+          arg_literals.reserve(operands.size());
+
+          // Construct scalar literal parameters to be passed to the map
+          // computation.
+          for (auto operand : operands) {
+            const Literal& arg_literal =
+                parent_->GetEvaluatedLiteralFor(operand);
+
+            auto curr_val = arg_literal.Get<NativeT>(multi_index);
+            auto curr_val_literal = Literal::CreateR0<NativeT>(curr_val);
+
+            arg_literals.push_back(std::move(curr_val_literal));
+          }
+
+          std::unique_ptr<Literal> computed_result =
+              embedded_evaluator
+                  .Evaluate<std::unique_ptr<Literal>>(*computation,
+                                                      arg_literals)
+                  .ConsumeValueOrDie();
+          // Clear visit states so that we can use the evaluator again on
+          // the same computation.
+          embedded_evaluator.ResetVisitStates();
+
+          return computed_result->Get<ReturnT>({});
+        }));
+    return std::move(result);
+  }
+
+  Status HandleMap(HloInstruction* map) override {
+    switch (map->operand(0)->shape().element_type()) {
+      case PRED: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<bool>(map));
+        break;
+      }
+      case U8: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint8>(map));
+        break;
+      }
+      case U32: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint32>(map));
+        break;
+      }
+      case U64: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<uint64>(map));
+        break;
+      }
+      case S8: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int8>(map));
+        break;
+      }
+      case S32: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int32>(map));
+        break;
+      }
+      case S64: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<int64>(map));
+        break;
+      }
+      case F16: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map],
+                            MapImpl<Eigen::half>(map));
+        break;
+      }
+      case F32: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<float>(map));
+        break;
+      }
+      case F64: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<double>(map));
+        break;
+      }
+      case C64: {
+        TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl<complex64>(map));
+        break;
+      }
+      default:
+        LOG(FATAL) << "HandleMap: unhandled primitive type for "
+                      "input operand: "
+                   << PrimitiveType_Name(
+                          map->operand(0)->shape().element_type());
+    }
+
+    return Status::OK();
+  }
+
+  Status HandleReduce(HloInstruction* reduce) override {
+    auto arg = reduce->operand(0);
+    auto init_value = reduce->operand(1);
+    tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+    HloComputation* function = reduce->to_apply();
+    TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) ==
+                 ShapeUtil::Rank(arg->shape()) - dimensions.size());
+    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+                        ShapeInference::InferReduceShape(
+                            /*arg=*/arg->shape(),
+                            /*init_value=*/init_value->shape(),
+                            /*dimensions_to_reduce=*/dimensions,
+                            /*to_apply=*/function->ComputeProgramShape()));
+    TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
+        << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
+        << " but is inferred to be: "
+        << ShapeUtil::HumanString(inferred_return_shape);
+
+    const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
+    VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString();
+    const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value);
+    VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString();
+    TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
+    auto init_scalar = init_literal.Get<ReturnT>({});
+
+    auto result = Literal::CreateFromShape(reduce->shape());
+
+    const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions());
+    std::vector<int64> arg_dim_steps(arg_dimensions.size());
+    std::vector<int64> arg_dim_counts(arg_dimensions.size());
+    for (const int64 dim : dimensions) {
+      arg_dim_steps[dim] = 1;
+      arg_dim_counts[dim] = arg_dimensions[dim];
+    }
+
+    // Map each dimension in the result to a dimension in arg that isn't
+    // being reduced.
+    std::vector<int64> result_to_arg_index;
+    for (int64 i = 0; i < arg_dimensions.size(); ++i) {
+      if (arg_dim_steps[i] == 0) {
+        result_to_arg_index.push_back(i);
+      }
+    }
+
+    HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
+    // For each resulting dimension, calculate and assign computed value.
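+    // For example, reducing a [2,3] operand over dimension 1 yields a [2]
+    // result; for result element i, `base` is {i, 0} and ForEachIndex walks
+    // {i, 0}, {i, 1}, {i, 2}, folding each operand element into result_val
+    // via the reduction function.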
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>( + [&](tensorflow::gtl::ArraySlice<int64> multi_index) { + ReturnT result_val = init_scalar; + + std::vector<int64> base(arg_dimensions.size()); + for (int64 i = 0; i < multi_index.size(); ++i) { + base[result_to_arg_index[i]] = multi_index[i]; + } + + // When the reduction is addition of floats, accumulate in a double + // for better precision. Also, avoid creating Literals for the + // intermediate results; it's much faster. + if (ShapeUtil::ElementIsFloating(init_literal.shape()) && + IsScalarAdd(function)) { + double computed_result = 0; + auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) { + computed_result += arg_literal.Get<float>(input_index); + return true; + }; + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return static_cast<ReturnT>(computed_result); + } + auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) { + auto curr_val = arg_literal.Get<ReturnT>(input_index); + + // Evaluate computation with specified literal operands. + auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val); + auto result_val_literal = Literal::CreateR0<ReturnT>(result_val); + std::vector<const Literal*> args = {result_val_literal.get(), + curr_val_literal.get()}; + + std::unique_ptr<Literal> computed_result = + embedded_evaluator.Evaluate<const Literal*>(*function, args) + .ConsumeValueOrDie(); + // Clear visit states so that we can use the evaluator again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + // Assign computed result to result_val. + result_val = computed_result->Get<ReturnT>({}); + return true; + }; + // Computes one element of the result, reducing all dimensions that + // contribute to that element. + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return result_val; + })); + + parent_->evaluated_[reduce] = std::move(result); + return Status::OK(); + } + + bool IsScalarAdd(HloComputation* computation) { + HloInstruction* instruction = computation->root_instruction(); + if (instruction->opcode() == HloOpcode::kAdd && + computation->num_parameters() == 2) { + const HloInstruction* lhs = instruction->operand(0); + const HloInstruction* rhs = instruction->operand(1); + return lhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(lhs->shape()) && + rhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; + } + return false; + } + + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { + auto operand = select_and_scatter->operand(0); + auto source = select_and_scatter->operand(1); + const Window& window = select_and_scatter->window(); + + const Literal& init_literal = + parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get<ReturnT>({}); + + auto result = Literal::CreateFromShape(select_and_scatter->shape()); + + // Initialize result array with the init value. 
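+    // (Any element that no window placement ends up scattering to simply
+    // keeps this init value in the final result.)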
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+          return init_scalar;
+        }));
+
+    std::vector<int64> window_dimension_sizes;
+    for (const auto& window_dimension : window.dimensions()) {
+      window_dimension_sizes.push_back(window_dimension.size());
+    }
+    const Shape window_shape = ShapeUtil::MakeShape(
+        operand->shape().element_type(), window_dimension_sizes);
+
+    HloComputation* select = select_and_scatter->select();
+    HloComputation* scatter = select_and_scatter->scatter();
+
+    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+    const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source);
+
+    int64 rank = ShapeUtil::Rank(operand_literal.shape());
+
+    HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
+    DimensionVector source_index(rank);
+
+    std::fill(source_index.begin(), source_index.end(), 0);
+    do {
+      // For each element in `source`, we place a window in `operand`. For
+      // each window placement, we iterate inside the window twice:
+      //
+      // 1. Find the selected index by applying the `select` function to all
+      // elements. E.g., if the `select` function is GreaterEqual, the first
+      // iteration through the window finds the biggest value and returns its
+      // index.
+      //
+      // 2. Using the selected index, scatter the value from `source` to the
+      // result. We do this by iterating through the window and comparing
+      // each index with the selected index.
+      tensorflow::gtl::optional<ReturnT> selected_val;
+      tensorflow::gtl::optional<std::vector<int64>> selected_index;
+
+      IterateThroughWindow(
+          window_shape, window, operand_literal.shape(), source_index,
+          [&](const std::vector<int64>& operand_index) {
+            auto curr_val = operand_literal.Get<ReturnT>(operand_index);
+            if (!selected_val) {
+              selected_val = curr_val;
+              selected_index = operand_index;
+            }
+            const auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
+            const auto selected_val_literal =
+                Literal::CreateR0<ReturnT>(*selected_val);
+
+            const std::vector<const Literal*> args = {
+                selected_val_literal.get(), curr_val_literal.get()};
+            std::unique_ptr<Literal> computed_result =
+                embedded_evaluator.Evaluate<const Literal*>(*select, args)
+                    .ConsumeValueOrDie();
+            bool selected = !computed_result->Get<bool>({});
+            if (selected) {
+              selected_val = curr_val;
+              selected_index = operand_index;
+            }
+            embedded_evaluator.ResetVisitStates();
+          });
+
+      IterateThroughWindow(
+          window_shape, window, operand_literal.shape(), source_index,
+          [&](const std::vector<int64>& operand_index) {
+            if (std::equal(operand_index.begin(), operand_index.end(),
+                           selected_index->begin())) {
+              auto source = source_literal.Get<ReturnT>(source_index);
+              auto scattered = result->Get<ReturnT>(operand_index);
+              const auto source_literal = Literal::CreateR0<ReturnT>(source);
+              const auto scattered_literal =
+                  Literal::CreateR0<ReturnT>(scattered);
+
+              const std::vector<const Literal*> args = {
+                  source_literal.get(), scattered_literal.get()};
+              std::unique_ptr<Literal> computed_result =
+                  embedded_evaluator.Evaluate<const Literal*>(*scatter, args)
+                      .ConsumeValueOrDie();
+              result->Set(operand_index, computed_result->Get<ReturnT>({}));
+              // Clear visit states so that we can use the evaluator again
+              // on the same computation.
+              embedded_evaluator.ResetVisitStates();
+            }
+          });
+    } while (IndexUtil::BumpIndices(source->shape(), &source_index));
+
+    parent_->evaluated_[select_and_scatter] = std::move(result);
+    return Status::OK();
+  }
+
+  Status HandleReduceWindow(HloInstruction* reduce_window) override {
+    auto operand = reduce_window->operand(0);
+    const Window& window = reduce_window->window();
+    HloComputation* function = reduce_window->to_apply();
+    TF_ASSIGN_OR_RETURN(
+        auto inferred_return_shape,
+        ShapeInference::InferReduceWindowShape(
+            /*operand_shape=*/reduce_window->operand(0)->shape(),
+            /*init_value=*/reduce_window->operand(1)->shape(), window,
+            /*to_apply_shape=*/function->ComputeProgramShape()));
+    TF_RET_CHECK(
+        ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
+        << "return shape is set to: "
+        << ShapeUtil::HumanStringWithLayout(reduce_window->shape())
+        << " but is inferred to be: "
+        << ShapeUtil::HumanStringWithLayout(inferred_return_shape);
+
+    const Literal& operand_literal =
+        parent_->GetEvaluatedLiteralFor(reduce_window->operand(0));
+    VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString();
+    const Literal& init_literal =
+        parent_->GetEvaluatedLiteralFor(reduce_window->operand(1));
+    VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString();
+    TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
+    auto init_scalar = init_literal.Get<ReturnT>({});
+
+    auto result = Literal::CreateFromShape(reduce_window->shape());
+
+    // Creates a Shape object from window, for iteration below.
+    std::vector<int64> window_dimension_sizes;
+    for (const auto& window_dimension : window.dimensions()) {
+      window_dimension_sizes.push_back(window_dimension.size());
+    }
+    const Shape window_shape = ShapeUtil::MakeShape(
+        operand->shape().element_type(), window_dimension_sizes);
+
+    DimensionVector window_index(window.dimensions_size());
+    DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
+
+    HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
+    // For each resulting dimension, calculate and assign computed value.
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+          ReturnT result_val = init_scalar;
+
+          std::fill(window_index.begin(), window_index.end(), 0);
+          std::fill(operand_index.begin(), operand_index.end(), 0);
+
+          IterateThroughWindow(
+              window_shape, window, operand_literal.shape(), output_index,
+              [&](const std::vector<int64>& operand_index) {
+                auto curr_val = operand_literal.Get<ReturnT>(operand_index);
+
+                // Evaluate computation with specified literal operands.
+                const auto curr_val_literal =
+                    Literal::CreateR0<ReturnT>(curr_val);
+                const auto result_val_literal =
+                    Literal::CreateR0<ReturnT>(result_val);
+                const std::vector<const Literal*> args = {
+                    result_val_literal.get(), curr_val_literal.get()};
+                std::unique_ptr<Literal> computed_result =
+                    embedded_evaluator.Evaluate<const Literal*>(*function, args)
+                        .ConsumeValueOrDie();
+
+                // Clear visit states so that we can use the evaluator again
+                // on the same computation.
+ embedded_evaluator.ResetVisitStates(); + + result_val = computed_result->Get<ReturnT>({}); + }); + + return result_val; + })); + + parent_->evaluated_[reduce_window] = std::move(result); + return Status::OK(); + } + + Status HandleSlice(HloInstruction* slice) override { + auto operand = slice->operand(0); + const Shape& shape = slice->shape(); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferSliceShape( + operand->shape(), slice->slice_starts(), + slice->slice_limits(), slice->slice_strides())); + TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const int64 rank = ShapeUtil::Rank(operand->shape()); + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) { + DimensionVector operand_index(rank); + for (int64 i = 0; i < rank; ++i) { + operand_index[i] = + slice->slice_starts(i) + out_index[i] * slice->slice_strides(i); + } + return operand_literal.Get<ReturnT>(operand_index); + }; + + auto result = Literal::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); + TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func)); + parent_->evaluated_[slice] = std::move(result); + return Status::OK(); + } + + // Enable CLZ only for int32 and uint32. + template < + typename NativeT, + typename std::enable_if< + (std::is_floating_point<NativeT>::value || + std::is_integral<NativeT>::value || is_complex_t<NativeT>::value) && + !(std::is_same<NativeT, uint32>::value || + std::is_same<NativeT, int32>::value)>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + return InvalidArgument("Unsupported type for Clz"); + } + + template <typename NativeT, + typename std::enable_if< + std::is_same<NativeT, uint32>::value || + std::is_same<NativeT, int32>::value>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], + ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { + return 31 - tensorflow::Log2Floor(elem_operand); + })); + return Status::OK(); + } + + Status HandleClz(HloInstruction* clz) override { + return HandleClz<ElementwiseT>(clz); + } + + template <typename NativeT, typename std::enable_if<std::is_floating_point< + NativeT>::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], + ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { + return std::sin(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<std::is_integral<NativeT>::value || + is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + return InvalidArgument("Unsupported type for Sin"); + } + + Status HandleSin(HloInstruction* sin) override { + return HandleSin<ElementwiseT>(sin); + } + + template <typename NativeT, typename std::enable_if<std::is_floating_point< + NativeT>::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], + ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { + return std::cos(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if<std::is_integral<NativeT>::value || + is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + return InvalidArgument("Unsupported 
type for Cos");
+  }
+
+  Status HandleCos(HloInstruction* cos) override {
+    return HandleCos<ElementwiseT>(cos);
+  }
+
+  template <typename NativeT, typename std::enable_if<std::is_same<
+                                  float, NativeT>::value>::type* = nullptr>
+  Status HandleReducePrecision(HloInstruction* reduce_precision) {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[reduce_precision],
+        ElementWiseUnaryOp(reduce_precision, [reduce_precision](
+                                                 ElementwiseT elem) {
+          uint32_t value_as_int = tensorflow::bit_cast<uint32_t>(elem);
+          const uint32_t mantissa_bits = reduce_precision->mantissa_bits();
+          const uint32_t exponent_bits = reduce_precision->exponent_bits();
+
+          // Code is based on the CPU/GPU implementation in LLVM-emitting
+          // code.
+          //
+          // Bits in float type:
+          //   mantissa : bits [0:22]
+          //   exponent : bits [23:30]
+          //   sign     : bits [31]
+          if (mantissa_bits < 23) {
+            const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
+
+            // Compute rounding bias for round-to-nearest with ties to even.
+            // This is equal to a base value of 0111... plus one bit if the
+            // last remaining mantissa bit is 1.
+            const uint32_t base_rounding_bias =
+                (last_mantissa_bit_mask >> 1) - 1;
+            const uint32_t x_last_mantissa_bit =
+                (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits);
+            const uint32_t x_rounding_bias =
+                x_last_mantissa_bit + base_rounding_bias;
+
+            // Add rounding bias, and mask out truncated bits. Note that the
+            // case where adding the rounding bias overflows into the exponent
+            // bits is correct; the non-masked mantissa bits will all be zero,
+            // and the exponent will be incremented by one.
+            const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
+            value_as_int = value_as_int + x_rounding_bias;
+            value_as_int = value_as_int & truncation_mask;
+          }
+          if (exponent_bits < 8) {
+            // Masks for f32 values.
+            const uint32_t f32_sign_bit_mask = 1u << 31;
+            const uint32_t f32_exp_bits_mask = 0xffu << 23;
+
+            // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in
+            // the most-significant bit -- is equal to 1.0f for all exponent
+            // sizes. Adding 2^(n-1)-1 to this gives us the highest
+            // non-infinite exponent for a bit size of n, and subtracting
+            // 2^(n-1)-1 from it gives us the lowest exponent (corresponding
+            // to 0.0f).
+            //
+            // Thus, the f32 exponent corresponding to the highest
+            // non-infinite exponent for a bit size of n is (2^7-1) +
+            // 2^(n-1)-1, and the f32 exponent corresponding to the lowest
+            // exponent for a bit size of n is (2^7-1) - 2^(n-1)-1.
+            //
+            // Note that we have already checked that exponent_bits >= 1.
+            const uint32_t f32_exponent_bias = (1 << 7) - 1;
+            const uint32_t reduced_exponent_bias =
+                (1 << (exponent_bits - 1)) - 1;
+            const uint32_t reduced_max_exponent =
+                f32_exponent_bias + reduced_exponent_bias;
+            const uint32_t reduced_min_exponent =
+                f32_exponent_bias - reduced_exponent_bias;
+
+            // Do we overflow or underflow?
+            const uint32_t x_exponent = value_as_int & f32_exp_bits_mask;
+            const bool x_overflows = x_exponent > (reduced_max_exponent << 23);
+            const bool x_underflows =
+                x_exponent <= (reduced_min_exponent << 23);
+
+            // Compute appropriately-signed values of zero and infinity.
+            const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask;
+            const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask;
+
+            // Force to zero or infinity if overflow or underflow. (Note that
+            // this truncates all denormal values to zero, rather than
+            // rounding them.)
+            value_as_int = x_overflows ? x_signed_inf : value_as_int;
+            value_as_int = x_underflows ? x_signed_zero : value_as_int;
+          }
+
+          float reduced_result = tensorflow::bit_cast<float>(value_as_int);
+          if (std::isnan(elem)) {
+            reduced_result = mantissa_bits > 0
+                                 ? elem
+                                 : std::numeric_limits<float>::infinity();
+          }
+          return reduced_result;
+        }));
+    return Status::OK();
+  }
+
+  template <typename NativeT, typename std::enable_if<std::is_same<
+                                  double, NativeT>::value>::type* = nullptr>
+  Status HandleReducePrecision(HloInstruction* reduce_precision) {
+    return InvalidArgument("Double not supported for reduce precision");
+  }
+
+  template <
+      typename NativeT,
+      typename std::enable_if<std::is_integral<NativeT>::value ||
+                              is_complex_t<NativeT>::value>::type* = nullptr>
+  Status HandleReducePrecision(HloInstruction* reduce_precision) {
+    return InvalidArgument("Unsupported type for reduce precision");
+  }
+
+  Status HandleReducePrecision(HloInstruction* reduce_precision) override {
+    return HandleReducePrecision<ElementwiseT>(reduce_precision);
+  }
+
+ private:
+  // Creates a vector of multipliers which can be used to create a linear
+  // index into shape.
+  //
+  // Given the multidimensional index {i1, ..., iN} and
+  // M = MakeDimMultipliers(shape), the corresponding linear index LI is
+  // simply
+  //
+  //   LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N].
+  //
+  // This lets you calculate LI given the multidimensional indices in any
+  // order.
+  static DimensionVector MakeDimMultipliers(const Shape& shape) {
+    DimensionVector v(ShapeUtil::Rank(shape));
+    int64 scale = 1;
+    for (auto dim : LayoutUtil::MinorToMajor(shape)) {
+      v[dim] = scale;
+      scale *= shape.dimensions(dim);
+    }
+    return v;
+  }
+
+  // For one particular placement of a window in a base shape (the placement
+  // is represented as `window_count_index`), iterates inside the window.
+  // Translates the window index into a base index. If the base index is
+  // within bounds, calls `f` with the base index.
+  static void IterateThroughWindow(
+      const Shape& window_shape, const Window& window, const Shape& base_shape,
+      const tensorflow::gtl::ArraySlice<int64>& window_count_index,
+      const std::function<void(const std::vector<int64>&)>& f) {
+    const int64 rank = ShapeUtil::Rank(base_shape);
+    DimensionVector window_index(rank);
+    std::fill(window_index.begin(), window_index.end(), 0);
+    do {
+      std::vector<int64> base_index(rank);
+      bool out_of_bound = false;
+      for (int64 i = 0; i < rank; ++i) {
+        base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
+                        window_index[i] - window.dimensions(i).padding_low();
+        if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
+          out_of_bound = true;
+          break;
+        }
+      }
+      if (!out_of_bound) {
+        f(base_index);
+      }
+    } while (IndexUtil::BumpIndices(window_shape, &window_index));
+  }
+
+  template <typename IndexT>
+  StatusOr<std::unique_ptr<Literal>> DynamicSlice(
+      const Literal& operand_literal, const Literal& start_indices_literal,
+      const Shape& result_shape) {
+    auto start_indices_typed = start_indices_literal.data<IndexT>();
+    std::vector<int64> start(start_indices_typed.begin(),
+                             start_indices_typed.end());
+
+    std::vector<int64> operand_indices(start.size());
+
+    auto result = Literal::CreateFromShape(result_shape);
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+          for (int64 i = 0; i < operand_indices.size(); ++i) {
+            CHECK_GE(multi_index[i] + start[i], 0);
+            // Mod is only used here to be consistent with the existing
+            // backends' behavior.
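+            // e.g. for an operand dimension of size 4, a start index of 3
+            // and a slice size of 2, the slice reads elements 3 and 0
+            // instead of failing a bounds check.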
+ operand_indices[i] = (multi_index[i] + start[i]) % + operand_literal.shape().dimensions(i); + } + + auto result = operand_literal.Get<ReturnT>(operand_indices); + return result; + })); + + return std::move(result); + } + + template <typename IndexT> + StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice( + const Literal& operand_literal, const Literal& update_literal, + const Literal& start_indices_literal) { + auto result = operand_literal.CloneToUnique(); + auto start_indices_typed = start_indices_literal.data<IndexT>(); + const auto rank = ShapeUtil::Rank(result->shape()); + std::vector<int64> start(rank, 0); + for (int64 i = 0; i < rank; ++i) { + // All other implementations currently wrap-around the index, so this + // should do so as well. + start[i] = (start_indices_typed[i] % result->shape().dimensions(i)); + start[i] += (start[i] < 0) * result->shape().dimensions(i); + } + std::vector<int64> result_index(rank, 0); + + auto func = [&](tensorflow::gtl::ArraySlice<int64> update_index) { + std::transform(update_index.begin(), update_index.end(), start.begin(), + result_index.begin(), std::plus<int64>()); + // Same as above, wrap-around only to match other implementations' + // semantics. + std::transform(result_index.begin(), result_index.end(), + result->shape().dimensions().begin(), result_index.begin(), + std::modulus<int64>()); + result->Set<ReturnT>(result_index, + update_literal.Get<ReturnT>(update_index)); + return true; + }; + + std::vector<int64> base(update_literal.shape().dimensions_size(), 0); + std::vector<int64> step(update_literal.shape().dimensions_size(), 1); + ShapeUtil::ForEachIndex(update_literal.shape(), base, + AsInt64Slice(update_literal.shape().dimensions()), + step, func); + + return std::move(result); + } + + StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp( + HloInstruction* instruction, + const std::function<ElementwiseT(ElementwiseT)>& unary_op) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(instruction->operand(0)); + TF_ASSIGN_OR_RETURN( + auto result_literal, + (HloEvaluator::ElementWiseUnaryOpImpl<ReturnT, ReturnT>( + instruction, ConvertUnaryFunction(unary_op), operand_literal))); + + return std::move(result_literal); + } + + StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp( + HloInstruction* instruction, + const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>& + binary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast + // is removed. 
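+    // Until then, mismatched shapes are rejected explicitly: e.g. adding an
+    // f32[2,3] lhs to an f32[3] rhs would need an implicit broadcast, which
+    // this evaluator does not implement.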
+    if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) &&
+          ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
+      return Unimplemented(
+          "Implicit broadcasting is currently unsupported in HLO evaluator; "
+          "shape mismatch: %s vs %s vs %s",
+          ShapeUtil::HumanString(shape).c_str(),
+          ShapeUtil::HumanString(lhs->shape()).c_str(),
+          ShapeUtil::HumanString(rhs->shape()).c_str());
+    }
+
+    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
+    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
+
+    auto result = Literal::CreateFromShape(shape);
+
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+          return ConvertBinaryFunction(binary_op)(
+              lhs_literal.Get<ReturnT>(multi_index),
+              rhs_literal.Get<ReturnT>(multi_index));
+        }));
+    return std::move(result);
+  }
+
+  template <typename LhsType, typename RhsType, typename EhsType>
+  StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
+      HloInstruction* instruction,
+      const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
+    const auto shape = instruction->shape();
+    const auto* lhs = instruction->operand(0);
+    const auto* rhs = instruction->operand(1);
+    const auto* ehs = instruction->operand(2);
+
+    // TODO(b/35950897, b/27796129): add DCHECK back once implicit
+    // broadcast is removed.
+    if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) &&
+          ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) &&
+          ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) {
+      return Unimplemented(
+          "Implicit broadcasting is currently unsupported in HLO evaluator; "
+          "shape mismatch: %s vs %s vs %s vs %s",
+          ShapeUtil::HumanString(shape).c_str(),
+          ShapeUtil::HumanString(lhs->shape()).c_str(),
+          ShapeUtil::HumanString(rhs->shape()).c_str(),
+          ShapeUtil::HumanString(ehs->shape()).c_str());
+    }
+
+    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
+    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
+    const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
+
+    auto result = Literal::CreateFromShape(shape);
+
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+          return ternary_op(lhs_literal.Get<LhsType>(multi_index),
+                            rhs_literal.Get<RhsType>(multi_index),
+                            ehs_literal.Get<EhsType>(multi_index));
+        }));
+
+    return std::move(result);
+  }
+
+  template <typename NativeT>
+  static bool IsShiftOutOfBounds(NativeT rhs) {
+    typedef typename std::make_unsigned<NativeT>::type UnsignedT;
+    UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT;
+    UnsignedT rhs_unsigned = static_cast<UnsignedT>(rhs);
+    return rhs_unsigned >= lhs_size_unsigned;
+  }
+
+  HloEvaluator* parent_;
+};
+
+// These extern templates prevent users of this class from implicitly
+// instantiating it. We explicitly instantiate this class in the various
+// hlo_evaluator_typed_visitor*.cc files.
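+// (With `extern template`, a translation unit that includes this header and
+// uses, say, HloEvaluatorTypedVisitor<float> links against the single
+// explicit instantiation in hlo_evaluator_typed_visitor_float.cc instead of
+// compiling its own copy, which keeps compile times and object sizes down.)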
+extern template class HloEvaluatorTypedVisitor<bool>; +extern template class HloEvaluatorTypedVisitor<uint8>; +extern template class HloEvaluatorTypedVisitor<uint32>; +extern template class HloEvaluatorTypedVisitor<uint64>; +extern template class HloEvaluatorTypedVisitor<int8>; +extern template class HloEvaluatorTypedVisitor<int32>; +extern template class HloEvaluatorTypedVisitor<int64>; +extern template class HloEvaluatorTypedVisitor<Eigen::half, float>; +extern template class HloEvaluatorTypedVisitor<float>; +extern template class HloEvaluatorTypedVisitor<double>; +extern template class HloEvaluatorTypedVisitor<complex64>; +extern template class HloEvaluatorTypedVisitor<bfloat16, float>; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc new file mode 100644 index 0000000000..39c352dfb9 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor<bfloat16, float>; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc new file mode 100644 index 0000000000..289b40fa06 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor<bool>; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc new file mode 100644 index 0000000000..9cb4eb921f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor<complex64>; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc new file mode 100644 index 0000000000..5e6252fbf8 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor<double>; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc new file mode 100644 index 0000000000..ee793ae77b --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor<float>; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc new file mode 100644 index 0000000000..038d9d39e4 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor<Eigen::half, float>; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc new file mode 100644 index 0000000000..b1952ca619 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor<int32>; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc new file mode 100644 index 0000000000..0cbaffb40b --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor<int64>; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc new file mode 100644 index 0000000000..6f4bf2a392 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor<int8>; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc new file mode 100644 index 0000000000..10235447e0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor<uint32>; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc new file mode 100644 index 0000000000..8abeaa6ffc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor<uint64>; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc new file mode 100644 index 0000000000..6dabd1c176 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor<uint8>; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index bb4db89f0a..b6b0387672 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -322,11 +322,13 @@ class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, const DebugOptions& debug_options, bool show_metadata, - const HloExecutionProfile* profile, NodeFilter filter) + bool show_backend_config, const HloExecutionProfile* profile, + NodeFilter filter) : computation_(computation), - label_(label.ToString()), + label_(std::string(label)), debug_options_(debug_options), show_metadata_(show_metadata), + show_backend_config_(show_backend_config), profile_(profile), filter_(std::move(filter)) {} @@ -365,6 +367,7 @@ class HloDotDumper { string GetInstructionNodeShape(const HloInstruction* instr); string GetInstructionNodeLabel(const HloInstruction* instr); string GetInstructionNodeMetadata(const HloInstruction* instr); + string GetInstructionNodeBackendConfig(const HloInstruction* instr); string GetInstructionNodeExtraInfo(const HloInstruction* instr); string GetInstructionNodeInlinedOperands(const HloInstruction* instr); void AddInstructionIncomingEdges(const HloInstruction* instr); @@ -393,6 +396,7 @@ class HloDotDumper { const string label_; // overall name for the graph const DebugOptions& debug_options_; const bool show_metadata_; + const bool show_backend_config_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -611,6 +615,10 @@ tooltip = " "; if (!extra_info.empty()) { StrAppend(&subcomp_label, "<br/>", extra_info); } + string node_backend_config = GetInstructionNodeBackendConfig(parent_instr); + if (!node_backend_config.empty()) { + StrAppend(&subcomp_label, "<br/>", node_backend_config); + } bool highlight = filter_.Highlight(parent_instr); const char* fillcolor; @@ -765,6 +773,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string node_shape = GetInstructionNodeShape(instr); string node_label = GetInstructionNodeLabel(instr); string node_metadata = GetInstructionNodeMetadata(instr); + string node_backend_config = GetInstructionNodeBackendConfig(instr); string extra_info = GetInstructionNodeExtraInfo(instr); string inlined_constants = GetInstructionNodeInlinedOperands(instr); string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); @@ -782,8 +791,8 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } // Build the text that will be displayed inside the node. 
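// Illustrative aside (assumed config value, not from the patch): for an
// instruction whose backend_config() is {"algo":3}, the loop below appends
// the string returned by GetInstructionNodeBackendConfig(), so the node
// body gains the line backend_config="{"algo":3}" right after the metadata
// line.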
string node_body = node_label; - for (const string& s : - {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) { + for (const string& s : {trivial_subcomputation, node_metadata, + node_backend_config, extra_info, inlined_constants}) { if (!s.empty()) { StrAppend(&node_body, "<br/>", s); } @@ -1078,6 +1087,15 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { return Join(lines, "<br/>"); } +string HloDotDumper::GetInstructionNodeBackendConfig( + const HloInstruction* instr) { + if (!show_backend_config_ || instr->backend_config().empty()) { + return ""; + } + + return StrCat("backend_config=\"", instr->backend_config(), "\""); +} + string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { std::vector<string> lines; @@ -1404,7 +1422,7 @@ string ExportGraph(const string& graph, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, - bool show_metadata) { + bool show_metadata, bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; string graph; if (debug_options.xla_hlo_dump_as_graphdef()) { @@ -1414,9 +1432,10 @@ string DumpGraph(const HloComputation& computation, const string& label, &graph)); graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { - graph = HloDotDumper(&computation, label, debug_options, show_metadata, - hlo_execution_profile, NodeFilter()) - .Dump(); + graph = + HloDotDumper(&computation, label, debug_options, show_metadata, + show_backend_config, hlo_execution_profile, NodeFilter()) + .Dump(); graph_kind = GraphRendererInterface::DOT_GRAPH; } @@ -1427,15 +1446,15 @@ string DumpGraph(const HloComputation& computation, const string& label, } string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata) { + bool show_metadata, bool show_backend_config) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); - string graph = - HloDotDumper(node.parent(), label, debug_options, show_metadata, - /*profile=*/nullptr, filter) - .Dump(); + string graph = HloDotDumper(node.parent(), label, debug_options, + show_metadata, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 2704aae1e3..fc8e1468ac 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_metadata = false); + bool show_metadata = false, bool show_backend_config = false); // Like DumpGraph, but renders only nodes "near" the given node in the graph. // @@ -64,7 +64,8 @@ string DumpGraph(const HloComputation& computation, const string& label, // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. 
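// Example (illustrative sketch; `computation`, `instr`, and `debug_options`
// are assumed to be in scope, and both new flags default to false so
// existing call sites keep compiling unchanged):
//
//   DumpGraph(*computation, /*label=*/"after-fusion", debug_options,
//             /*hlo_execution_profile=*/nullptr,
//             /*show_metadata=*/true, /*show_backend_config=*/true);
//   DumpNeighborhoodAround(*instr, /*radius=*/3,
//                          /*show_metadata=*/false,
//                          /*show_backend_config=*/true);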
string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata = false); + bool show_metadata = false, + bool show_backend_config = false); // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a714d0e114..857cd39adb 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -109,6 +109,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->name_ = proto.name(); instruction->metadata_ = proto.metadata(); + instruction->set_backend_config(proto.backend_config()); if (proto.has_literal()) { TF_ASSIGN_OR_RETURN(instruction->literal_, Literal::CreateFromProto(proto.literal())); @@ -437,7 +438,7 @@ HloInstruction::CreateCrossReplicaSum( << "Outfeed shape " << shape << " must be compatible with operand shape " << operand->shape(); instruction->AppendOperand(operand); - instruction->outfeed_config_ = outfeed_config.ToString(); + instruction->outfeed_config_ = std::string(outfeed_config); instruction->outfeed_shape_ = shape; return instruction; } @@ -792,23 +793,11 @@ HloInstruction::CreateBroadcastSequence( return instruction; } -// We put the fusion kind into the instruction's name for transpose-dot fusions, -// since those fusions are really just describing a type of dot rather than -// generating a novel computation. -static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { - switch (fusion_kind) { - case HloInstruction::FusionKind::kTransposeDot: - return "dot_fusion"; - default: - return "fusion"; - } -} - /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); instruction->fusion_kind_ = fusion_kind; - instruction->name_ = FusionNodeName(fusion_kind); + instruction->name_ = "fusion"; instruction->set_parent(fused_root->parent()); instruction->set_metadata(fused_root->metadata()); instruction->CloneAndFuseInternal(fused_root); @@ -824,7 +813,7 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { instruction->AppendOperand(operand); } instruction->fusion_kind_ = fusion_kind; - instruction->name_ = FusionNodeName(fusion_kind); + instruction->name_ = "fusion"; instruction->called_computations_.push_back(fusion_computation); fusion_computation->SetFusionInstruction(instruction.get()); return instruction; @@ -1167,7 +1156,7 @@ bool HloInstruction::HasSideEffect() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->custom_call_target_ = custom_call_target.ToString(); + instruction->custom_call_target_ = std::string(custom_call_target); return instruction; } @@ -1179,7 +1168,7 @@ bool HloInstruction::HasSideEffect() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->channel_name_ = channel_name.ToString(); + instruction->channel_name_ = std::string(channel_name); instruction->cost_estimate_ns_ = cost_estimate_ns; return instruction; } @@ -1231,12 +1220,15 @@ bool HloInstruction::HasSideEffect() const { std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, - HloModule* module) const { + HloModule* module, 
CloneMap* clone_map) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { VLOG(3) << " %" << new_operand->name(); } + if (module == nullptr) { + module = GetModule(); + } std::unique_ptr<HloInstruction> clone; @@ -1342,7 +1334,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( break; case HloOpcode::kFft: CHECK_EQ(new_operands.size(), 1); - return CreateFft(shape, new_operands[0], fft_type_, fft_length_); + clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_); + break; case HloOpcode::kCrossReplicaSum: clone = CreateCrossReplicaSum(shape, new_operands); break; @@ -1415,9 +1408,15 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kConstant: clone = CreateConstant(literal_->CloneToUnique()); break; - case HloOpcode::kFusion: - clone = CloneFusionWithNewOperands(shape, new_operands, module); + case HloOpcode::kFusion: { + CHECK_NE(module, nullptr); + auto new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", module, clone_map)); + clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(), + /*operands=*/new_operands, + /*fusion_computation=*/new_fused_computation); break; + } case HloOpcode::kParameter: clone = CreateParameter(parameter_number_, shape, name_); break; @@ -1481,15 +1480,19 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); + clone->set_backend_config(backend_config()); + if (clone_map != nullptr) { + InsertOrDie(clone_map, this, clone.get()); + } return clone; } HloInstruction::~HloInstruction() {} -std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix, - HloModule* module) const { +std::unique_ptr<HloInstruction> HloInstruction::Clone( + const string& suffix, HloModule* module, CloneMap* clone_map) const { std::unique_ptr<HloInstruction> clone = - CloneWithNewOperands(shape_, operands_, module); + CloneWithNewOperands(shape_, operands_, module, clone_map); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1526,71 +1529,6 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix, return clone; } -std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - HloModule* module) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(parent() != nullptr); - - auto new_instruction = - WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - // Add the operands to our new fusion instruction. - for (HloInstruction* new_operand : operands) { - new_instruction->AppendOperand(new_operand); - } - // Clone all the fused instructions for the new fusion instruction. - HloInstructionMap<HloInstruction*> old_to_new; - std::list<std::unique_ptr<HloInstruction>> new_fused_instructions; - // Create the list of fused parameters by mapping through the cloned, - // fused instructions. 
- for (HloInstruction* old_fused_parameter : - fused_instructions_computation()->parameter_instructions()) { - new_fused_instructions.push_back( - old_fused_parameter->Clone("clone", module)); - HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); - InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); - } - for (auto old_fused_instruction : - fused_instructions_computation()->MakeInstructionPostOrder()) { - if (old_fused_instruction->opcode() == HloOpcode::kParameter) { - FindOrDie(old_to_new, old_fused_instruction); - continue; - } - std::vector<HloInstruction*> new_operands; - for (int64 operand_idx = 0; - operand_idx < old_fused_instruction->operand_count(); ++operand_idx) { - HloInstruction* old_operand = - old_fused_instruction->mutable_operand(operand_idx); - new_operands.push_back(FindOrDie(old_to_new, old_operand)); - } - new_fused_instructions.push_back( - old_fused_instruction->CloneWithNewOperands( - old_fused_instruction->shape(), new_operands, module)); - HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); - new_fused_instruction->set_parent(parent_); - InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); - } - new_instruction->fusion_kind_ = fusion_kind_; - auto computation_builder = HloComputation::Builder( - fused_instructions_computation()->name() + ".clone", - new_instruction.get()); - // We iterated the fusion instructions in reverse post order which means - // that we must reverse our new list of fusion instructions. - for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); - new_fused_instruction_iter != new_fused_instructions.rend(); - ++new_fused_instruction_iter) { - computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); - } - if (module == nullptr) { - module = GetModule(); - } - auto fused_root_ = fused_expression_root(); - new_instruction->called_computations_.push_back( - CHECK_NOTNULL(module)->AddEmbeddedComputation( - computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); - return new_instruction; -} - std::pair<const HloInstruction*, ShapeIndex> HloInstruction::LatestNonGteAncestorAndIndex() const { const HloInstruction* hlo = this; @@ -2172,6 +2110,9 @@ string HloInstruction::ToString(const HloPrintOptions& options) const { !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } + if (options.print_backend_config() && !backend_config().empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\""); + } return result; } @@ -2357,6 +2298,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); } + return extra; } @@ -2386,6 +2328,7 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_metadata() = metadata_; + proto.set_backend_config(backend_config()); if (literal_ != nullptr) { *proto.mutable_literal() = literal_->ToProto(); } @@ -2487,8 +2430,6 @@ string HloInstruction::ToCategory() const { return "input fusion"; case FusionKind::kOutput: return "output fusion"; - case FusionKind::kTransposeDot: - return "dot"; case FusionKind::kCustom: return "custom fusion"; } @@ -2971,6 +2912,7 @@ Status HloInstruction::AcceptOrdered( continue; } + // TODO(b/78350259): Eliminate const laundering. 
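// (Assumed rationale, stated here for clarity rather than taken from the
// patch.) AcceptOrdered is handed const HloInstruction* from the ordered
// sequence but must invoke non-const visitor hooks, so it launders the
// pointer with const_cast; HloModule::LaunderConstInstructionFromModule(),
// added elsewhere in this change, wraps the same cast behind a same-module
// TF_RET_CHECK.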
HloInstruction* instruction = const_cast<HloInstruction*>(const_instruction); @@ -3270,8 +3212,6 @@ string ToString(HloInstruction::FusionKind kind) { return "kInput"; case HloInstruction::FusionKind::kOutput: return "kOutput"; - case HloInstruction::FusionKind::kTransposeDot: - return "kTransposeDot"; case HloInstruction::FusionKind::kCustom: return "kCustom"; } @@ -3288,9 +3228,6 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind( if (kind_name == "kOutput") { return HloInstruction::FusionKind::kOutput; } - if (kind_name == "kTransposeDot") { - return HloInstruction::FusionKind::kTransposeDot; - } if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a5e9aecb9e..14be58d069 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -66,6 +66,7 @@ class HloPrintOptions { : print_large_constants_(false), print_subcomputation_references_(true), print_metadata_(true), + print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), print_program_shape_(true), @@ -77,6 +78,7 @@ class HloPrintOptions { .set_print_large_constants(true) .set_print_subcomputation_references(true) .set_print_metadata(false) + .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) .set_print_percent(false); @@ -99,12 +101,18 @@ class HloPrintOptions { return *this; } - // If true, metatdata will be printed. + // If true, metadata will be printed. HloPrintOptions& set_print_metadata(bool value) { print_metadata_ = value; return *this; } + // If true, backend_config will be printed. + HloPrintOptions& set_print_backend_config(bool value) { + print_backend_config_ = value; + return *this; + } + // If true, operands' shapes will be printed. HloPrintOptions& set_print_operand_shape(bool value) { print_operand_shape_ = value; @@ -141,6 +149,7 @@ class HloPrintOptions { return print_subcomputation_references_; } bool print_metadata() const { return print_metadata_; } + bool print_backend_config() const { return print_backend_config_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } @@ -151,6 +160,7 @@ class HloPrintOptions { bool print_large_constants_; bool print_subcomputation_references_; bool print_metadata_; + bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; bool print_program_shape_; @@ -167,7 +177,6 @@ class HloInstruction { kOutput, // Op's output is fused into the op itself. // REQUIRES: At least one operand buffer must be able // to alias the output buffer. - kTransposeDot, // Fused into a dot with transposed operands. kCustom, // Custom category for backend-specific fusions that // do not match any of the more specific ones. }; @@ -643,6 +652,8 @@ class HloInstruction { // Detaches an instruction from its operands. That is, remove the instruction // from each operand's user set. This should only be called prior to // deallocating the instruction. + // + // TODO(b/78305363): Make this automatic when deleting an instruction. void DetachFromOperands(); // Performs a postorder DFS visit using this node as the root.
If @@ -1157,23 +1168,30 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kRng RandomDistribution random_distribution() const; + // See documentation for Clone(). + using CloneMap = std::unordered_map<const HloInstruction*, HloInstruction*>; + // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of - // the instruction to form the name of the cloned instruction. If the module - // pointer is not nullptr, it will be the module where the cloned computations - // will be added to (in order to support deep cloning). Ignores the control - // predecessors and successors of this HLO instruction. + // the instruction to form the name of the cloned instruction. Ignores the + // control predecessors and successors of this HLO instruction. + // + // If the module pointer is not nullptr, then any cloned computations will be + // added to this module in order to support deep cloning. Otherwise the module + // of the instruction is used. + // + // If clone_map is not nullptr, then each original instruction that is cloned + // will be inserted and map to its clone. clone_map should not already contain + // any of the instructions to clone. std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone", - HloModule* module = nullptr) const; + HloModule* module = nullptr, + CloneMap* clone_map = nullptr) const; - // Clones the HLO instruction as above but with new shape and operands. If - // the module pointer is not nullptr, it will be the module where the cloned - // computations will be added to (in order to support deep cloning). Ignores - // the control predecessors and successors of this HLO instruction. + // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr<HloInstruction> CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - HloModule* module = nullptr) const; + HloModule* module = nullptr, CloneMap* clone_map = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector<HloComputation*>& called_computations() const { @@ -1245,7 +1263,7 @@ class HloInstruction { // Gets/sets the string identifier for this instruction. const string& name() const { return name_; } - void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); } + void set_name(tensorflow::StringPiece name) { name_ = std::string(name); } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. @@ -1262,6 +1280,19 @@ class HloInstruction { // if no id has been assigned yet). int unique_id() const { return unique_id_; } + // Returns the backend-specific configuration for how a backend should compile + // this HLO. The meaning of the field is backend specific. Not for use before + // or during general HLO optimization, since HLO optimizations do not preserve + // this field and they cannot interpret it due to its meaning being backend + // specific. + // + // TODO(b/78194644): Introduce structured configuration format as per + // go/xla-heuristics. + const string& backend_config() const { return backend_config_; } + void set_backend_config(string backend_config) { + backend_config_ = std::move(backend_config); + } + // Sets the debug metadata for this instruction. 
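// Example for the Clone()/backend_config() additions above (illustrative;
// `instr` and `module` are assumed):
//
//   instr->set_backend_config("{\"algorithm\":7}");
//   HloInstruction::CloneMap clone_map;
//   auto clone = instr->Clone(/*suffix=*/"clone", module, &clone_map);
//   // The deep clone preserves the config and records every cloned
//   // instruction, including `instr` itself:
//   CHECK_EQ(clone_map.at(instr), clone.get());
//   CHECK_EQ(clone->backend_config(), instr->backend_config());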
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1283,6 +1314,7 @@ class HloInstruction { // Get/Set the number of partitions per outer dimension (in order, starting // with outer-most dimension first). Currently used by the parallel cpu // backend to partition HLOs into parallel tasks. + // // TODO(b/62783254) Replace these methods with a more general way to // annotate HLOs with backend-specific information. const std::vector<int64>& outer_dimension_partitions() const { @@ -1510,6 +1542,10 @@ class HloInstruction { // The string representation of the infeed configuration. string infeed_config_; + // The backend-specific configuration for how a backend should compile this + // HLO. See the documentation on backend_config(). + string backend_config_; + // String identifier for instruction. string name_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 5b65b1152c..909cdc0b62 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1102,7 +1102,7 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( - {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); + {dot, reshape}, HloInstruction::FusionKind::kLoop); auto fusion2 = fusion->Clone(); const HloInstruction* root = fusion->fused_expression_root(); @@ -1169,7 +1169,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { auto computation = module->AddEntryComputation(builder.Build()); auto nested_fusion = computation->CreateFusionInstruction( - {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); + {dot, b_t}, HloInstruction::FusionKind::kLoop); auto fusion = computation->CreateFusionInstruction( {add, nested_fusion}, HloInstruction::FusionKind::kOutput); @@ -1246,13 +1246,6 @@ TEST_F(HloInstructionTest, Stringification) { auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); - HloInstruction* fusion = computation->CreateFusionInstruction( - {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); - - EXPECT_EQ( - fusion->ToString(options), - "%dot_fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " - "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 69deac263e..7e4b883435 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -17,10 +17,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace testing { +using ::tensorflow::str_util::Join; + bool HloMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { @@ -195,6 +198,41 @@ void HloShardingMatcher::DescribeTo(std::ostream* os) const { } } +bool HloDotWithContractingDimsMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + + const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers(); + if (dim_nums.lhs_contracting_dimensions_size() != 1 || + dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { + *listener << instruction->ToString() + << " has wrong lhs_contracting_dimensions (got {" + << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {" + << lhs_contracting_dim_ << "})"; + return false; + } + + if (dim_nums.rhs_contracting_dimensions_size() != 1 || + dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { + *listener << instruction->ToString() + << " has wrong rhs_contracting_dimensions (got {" + << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {" + << rhs_contracting_dim_ << "})"; + return false; + } + + return true; +} + +void HloDotWithContractingDimsMatcher::DescribeTo(std::ostream* os) const { + HloMatcher::DescribeTo(os); + *os << " with lhs_contracting_dims={" << lhs_contracting_dim_ + << "} and rhs_contracting_dims={" << rhs_contracting_dim_ << "}"; +} + } // namespace testing void PrintTo(const HloInstruction* inst, ::std::ostream* os) { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 5175736a25..c33bdadf1c 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -131,6 +131,27 @@ class HloShardingMatcher tensorflow::gtl::optional<HloSharding> sharding_; }; +// Matches a Dot HLO instruction with specific LHS and RHS contracting +// dimensions. +class HloDotWithContractingDimsMatcher : public HloMatcher { + public: + explicit HloDotWithContractingDimsMatcher( + ::testing::Matcher<const HloInstruction*> lhs, + ::testing::Matcher<const HloInstruction*> rhs, int64 lhs_contracting_dim, + int64 rhs_contracting_dim) + : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}), + lhs_contracting_dim_(lhs_contracting_dim), + rhs_contracting_dim_(rhs_contracting_dim) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + int64 lhs_contracting_dim_; + int64 rhs_contracting_dim_; +}; + // HloInstruction* matchers for opcode and operands. 
Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -158,7 +179,6 @@ HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); HLO_MATCHER(Divide); -HLO_MATCHER(Dot); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); HLO_MATCHER(Eq); @@ -310,6 +330,30 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() { new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt)); } +inline ::testing::Matcher<const ::xla::HloInstruction*> Dot( + ::testing::Matcher<const HloInstruction*> lhs_matcher, + ::testing::Matcher<const HloInstruction*> rhs_matcher) { + return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( + ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher})); +} + +// Matches a Dot HLO instruction if it has exactly one lhs contracting dimension +// equal to `lhs_contracting_dim` and exactly one rhs contracting dimension +// equal to `rhs_contracting_dim`. +// +// Currently the HLO verifier rejects Dot operations with more than one +// contracting dimension (even though we can represent these in the +// DotDimensionNumbers proto) so there is no need to generalize this to support +// multiple contracting dimensions. +inline ::testing::Matcher<const ::xla::HloInstruction*> Dot( + ::testing::Matcher<const HloInstruction*> lhs_matcher, + ::testing::Matcher<const HloInstruction*> rhs_matcher, + int64 lhs_contracting_dim, int64 rhs_contracting_dim) { + return ::testing::MakeMatcher( + new ::xla::testing::HloDotWithContractingDimsMatcher( + lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim)); +} + #undef HLO_MATCHER } // namespace opcode_matchers diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index f2463060b7..016cc01e33 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace op = xla::testing::opcode_matchers; using ::testing::_; @@ -165,5 +166,41 @@ TEST(HloMatchersTest, ShardingMatcher) { "has incorrect sharding (expected: {maximal device=0})"); } +TEST(HloMatchersTest, DotMatcher) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + ROOT dot = f32[1,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + tools::Parse(hlo_string)); + HloInstruction* root = module->entry_computation()->root_instruction(); + + EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/0)); + + EXPECT_THAT( + Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/0, + /*rhs_contracting_dim=*/0)), + "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "lhs_contracting_dimensions (got {1} want {0})"); + + EXPECT_THAT( + Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/1)), + "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "rhs_contracting_dimensions (got {0} want {1})"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index c7a7192867..5308fb5848 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -46,6 +46,18 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config) config_(config), unique_id_(next_unique_module_id_++) {} +StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule( + const HloInstruction* hlo) { + if (hlo == nullptr) { + return nullptr; + } + + TF_RET_CHECK(hlo->GetModule() == this); + + // TODO(b/78350259): Eliminate const laundering. + return const_cast<HloInstruction*>(hlo); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, bool uniquify_names) { diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index f9674df812..1604a72612 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -217,6 +217,25 @@ class HloModule { // the lifetime of this process. int unique_id() const { return unique_id_; } + // Returns a non-const version of the passed-in const HloInstruction*. This is + // safe on the argument that if you have a non-const module, then you can + // access all instructions in the module as non-const. + // + // Returns an error if the passed-in instruction is not from this module, + // except that it is allowed to pass in a null pointer. + // + // TODO(b/78350259): Eliminate const laundering. The argument above is not + // reliable since at any time someone could add or discover a way for a + // non-const module to transitively contain a const HloInstruction. 
The + // reliable way to do this would be to create a const laundering map from a + // module, mapping each encountered HloInstruction to its non-const version + // and then look up each instruction in need of laundering in that map, but + // this is much more expensive and complicated. This returns a Status instead + // of doing a CHECK-failure in part to make it strongly apparent that this is + // something that can fail. + StatusOr<HloInstruction*> LaunderConstInstructionFromModule( + const HloInstruction* hlo); + private: HloComputation* AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 5120775737..d8f1ab916b 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -90,7 +90,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) { return Status::OK(); }; - string prefix = name().ToString() + ": pipeline start"; + string prefix = std::string(name()) + ": pipeline start"; bool changed = false; string message; TF_RETURN_IF_ERROR( @@ -98,12 +98,12 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) { const string xla_dump_per_pass_hlo_proto_to = module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, name().ToString(), - "pipeline_start"); + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, + std::string(name()), "pipeline_start"); } for (auto& pass : passes_) { - if (disabled_passes.count(pass->name().ToString()) > 0) { + if (disabled_passes.count(std::string(pass->name())) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() << ", disabled by --xla_disable_hlo_passes"; continue; @@ -121,7 +121,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) { run_invariant_checkers(StrCat("after running pass: ", pass->name()))); if (!xla_dump_per_pass_hlo_proto_to.empty()) { DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - name().ToString(), pass->name().ToString()); + std::string(name()), std::string(pass->name())); } changed |= changed_this_pass; diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 1a767628f6..23ace5afea 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -430,6 +430,15 @@ StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler( return ListScheduler::Run(computation, points_to_analysis, size_function); } +StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + const auto& post_order = computation.MakeInstructionPostOrder(); + return std::vector<const HloInstruction*>{post_order.begin(), + post_order.end()}; +} + StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, @@ -459,7 +468,22 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler( size_function)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); - if (list_memory <= dfs_memory) { + TF_ASSIGN_OR_RETURN( + std::vector<const HloInstruction*> post_order_sequence, + PostOrderMemoryScheduler(computation, 
points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN( + const int64 post_order_memory, + MinimumMemoryForComputation(computation, post_order_sequence, + points_to_analysis, size_function)); + VLOG(2) << "Min-memory post order sequence: " + << HumanReadableNumBytes(post_order_memory); + + if (post_order_memory < std::min(list_memory, dfs_memory)) { + VLOG(2) << "Chose min-memory post_order sequence: " + << HumanReadableNumBytes(post_order_memory); + return post_order_sequence; + + } else if (list_memory <= dfs_memory) { VLOG(2) << "Chose min-memory list sequence: " << HumanReadableNumBytes(list_memory); return list_sequence; diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 068e68383d..fcb006f818 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -55,6 +55,12 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function); +// Naive Post Order scheduler +StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + // The default scheduling algorithm. Runs both the list scheduler // and the DFS scheduler, and chooses whichever returns a lower min-memory, // not accounting for fragmentation. diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 8a30cbf9cd..096ebb7946 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -116,7 +116,7 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // produces no HLO value in the graph. 
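// (Assumed rationale, not spelled out in the patch.) Once a module reaches
// the verifier, a shape mismatch can only have been produced by a buggy
// compiler pass, not by malformed user input, which is why the checks here
// and below now return InternalError rather than InvalidArgument or
// FailedPrecondition.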
if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { - return InvalidArgument( + return InternalError( "Expected outfeed to have shape compatible with operand's shape %s, " "actual shape is %s:\n%s", ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), @@ -200,7 +200,7 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { transpose->operand(0)->shape(), transpose->dimensions())); } -Status ShapeVerifier::HandleParameter(HloInstruction*) { +Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { return tensorflow::Status::OK(); } @@ -410,7 +410,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { if (fp_type == PRIMITIVE_TYPE_INVALID) { fp_type = subshape.element_type(); } else if (fp_type != subshape.element_type()) { - return FailedPrecondition( + return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", instruction->ToString().c_str()); @@ -490,7 +490,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } } if (!compatible) { - return InvalidArgument( + return InternalError( "Expected instruction to have shape compatible with %s, actual " "shape is %s:\n%s", ShapeUtil::HumanString(inferred_shape).c_str(), @@ -541,7 +541,7 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2) { if (instr1->channel_id() != instr2->channel_id()) { - return FailedPrecondition( + return InternalError( "Expected to have the same channel id, actual channel ids are: %s " "(%lld), %s (%lld)", instr1->ToString().c_str(), instr1->channel_id(), @@ -571,22 +571,22 @@ string ComputationsToString( Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { - return FailedPrecondition("Computation %s has a null parent pointer", - computation->name().c_str()); + return InternalError("Computation %s has a null parent pointer", + computation->name().c_str()); } if (computation->parent() != module) { - return FailedPrecondition( + return InternalError( "Computation %s parent() does not point to parent module", computation->name().c_str()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { - return FailedPrecondition("Instruction %s has a null parent pointer", - instruction->name().c_str()); + return InternalError("Instruction %s has a null parent pointer", + instruction->name().c_str()); } if (instruction->parent() != computation) { - return FailedPrecondition( + return InternalError( "Instruction %s parent() does not point to parent computation", instruction->name().c_str()); } @@ -602,7 +602,7 @@ Status VerifyHloStructure(HloModule* module) { for (int i = 0; i < instruction->operand_count(); ++i) { const HloInstruction* operand = instruction->operand(i); if (operand->parent() != instruction->parent()) { - return FailedPrecondition( + return InternalError( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", i, operand->name().c_str(), instruction->name().c_str(), @@ -619,7 +619,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // The parent fusion instruction of the fusion computation must be 'fusion'. 
HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { - return FailedPrecondition( + return InternalError( "Instruction of fused computation does not match expected instruction " "%s.", fusion->ToString().c_str()); @@ -635,37 +635,37 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto* instruction : fused_computation->instructions()) { if (fused_root == instruction) { if (root_owned) { - return FailedPrecondition("Root appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Root appears more than once in %s.", + fusion->ToString().c_str()); } root_owned = true; } for (int i = 0; i < fused_parameters.size(); ++i) { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { - return FailedPrecondition("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Parameter appears more than once in %s.", + fusion->ToString().c_str()); } parameter_owned[i] = true; } } } if (!root_owned) { - return FailedPrecondition("Root not found in computation of %s.", - fusion->ToString().c_str()); + return InternalError("Root not found in computation of %s.", + fusion->ToString().c_str()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { - return FailedPrecondition("Parameter %d not found in computation of %s.", - i, fusion->ToString().c_str()); + return InternalError("Parameter %d not found in computation of %s.", i, + fusion->ToString().c_str()); } } // Fused root must have no users. if (fused_root->user_count() != 0) { - return FailedPrecondition("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", + fusion->ToString().c_str()); } // All uses of fused instructions must be in the fusion computation, and every @@ -674,13 +674,13 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { fusion->fused_instructions_computation()->instructions()) { if (instruction != fused_root) { if (instruction->user_count() == 0) { - return FailedPrecondition( - "Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + return InternalError("Non-root instruction %s in %s must have users.", + instruction->ToString().c_str(), + fusion->ToString().c_str()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { - return FailedPrecondition( + return InternalError( "Non-root instruction %s in %s may not have external users.", instruction->ToString().c_str(), fusion->ToString().c_str()); } @@ -695,34 +695,33 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return FailedPrecondition( - "Unexpected negative parameter number %lld in %s.", param_no, - fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %lld in %s.", + param_no, fusion->ToString().c_str()); } if (param_no >= fused_parameters.size()) { - return FailedPrecondition( + return InternalError( "Unexpected parameter number %lld in %s: higher then number of " "parameters %lu.", param_no, fusion->ToString().c_str(), fused_parameters.size()); } if (parameter_numbers[param_no]) { - return FailedPrecondition( + return InternalError( "Did not expect 
parameter number %lld more than once in %s.", param_no, fusion->ToString().c_str()); } parameter_numbers[param_no] = true; if (!ShapeUtil::Compatible(fused_param->shape(), fusion->operand(param_no)->shape())) { - return FailedPrecondition( + return InternalError( "Shape mismatch between parameter number %lld and its operand in %s.", param_no, fusion->ToString().c_str()); } } - // Make sure all the parameter_numbers entries were seen + // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { - return FailedPrecondition("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + return InternalError("Did not see parameter number %d in %s.", i, + fusion->ToString().c_str()); } } diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index fc24acd271..fb36d3a0d6 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -32,7 +32,7 @@ class HumanReadableProfileBuilder { explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name, int64 total_cycles, double clock_rate_ghz) - : computation_name_(computation_name.ToString()), + : computation_name_(std::string(computation_name)), total_cycles_(total_cycles), clock_rate_ghz_(clock_rate_ghz) { CHECK_GE(clock_rate_ghz, 1e-9); @@ -47,9 +47,10 @@ class HumanReadableProfileBuilder { tensorflow::StringPiece category, int64 cycles, int64 flop_count, int64 transcendental_count, int64 bytes_accessed, float optimal_seconds) { - op_infos_.push_back( - {op_name.ToString(), short_name.ToString(), category.ToString(), cycles, - flop_count, transcendental_count, bytes_accessed, optimal_seconds}); + op_infos_.push_back({std::string(op_name), std::string(short_name), + std::string(category), cycles, flop_count, + transcendental_count, bytes_accessed, + optimal_seconds}); } // Gets the human-readable profile. diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index dc1a39e9fa..6bb2ca19fe 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -28,6 +28,25 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { +// These nodes can always be duplicated into consumers, even if +// InstructionFusion::may_duplicate_ is false. +// +// In general these should be nodes that get *cheaper* the more they're +// duplicated (and fused into consumers). +// +// TODO(jlebar): Duplicating instructions when we have a variable called "may +// duplicate" that's equal to false is not pretty. +bool IsAlwaysDuplicable(const HloInstruction& instruction) { + // We are always willing to duplicate a widening type-conversion instruction + // if it means we can fuse the convert into a consumer. This allows the + // consumer to read less memory, which is almost always a performance win. 
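// Worked example (illustrative): convert(f16[100]) -> f32[100] reads
// 100 * 2 = 200 bytes and writes 100 * 4 = 400 bytes, so the operand is
// strictly smaller than the result. Duplicating the convert into each
// consumer means every consumer reads the 200-byte f16 buffer instead of a
// 400-byte f32 temporary; the test
// WideningConvertsAreAlwaysDuplicableIntoConsumers below exercises exactly
// this case.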
+ return instruction.opcode() == HloOpcode::kConvert && + ShapeUtil::ByteSizeOf(instruction.operand(0)->shape()) < + ShapeUtil::ByteSizeOf(instruction.shape()); +} +} // namespace + /*static*/ bool InstructionFusion::IsExpensive( const HloInstruction& instruction) { switch (instruction.opcode()) { @@ -418,9 +437,11 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, bool InstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Cost condition: don't duplicate expensive instructions. if (FusionWouldDuplicate(*producer, *consumer) && - (is_expensive_(*producer) || !may_duplicate_)) { + (!may_duplicate_ || is_expensive_(*producer)) && + !IsAlwaysDuplicable(*producer)) { return false; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index e78b99a80c..6dd8fa1ab0 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -21,6 +21,8 @@ limitations under the License. namespace xla { +namespace op = xla::testing::opcode_matchers; + using InstructionFusionTest = HloTestBase; TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { @@ -124,7 +126,7 @@ TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); // Make sure the add hasn't been duplicated. - EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); } TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { @@ -291,4 +293,29 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, + WideningConvertsAreAlwaysDuplicableIntoConsumers) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY Test { + p0 = f16[100] parameter(0) + c = f32[100] convert(p0) + add = f32[100] add(c, c) + ROOT mul = f32[100] multiply(c, c) + })") + .ValueOrDie(); + + // The convert should be fused into the add and mul, even though may_duplicate + // is false, because it's always beneficial to fuse/duplicate widening + // converts into consumers. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion(op::Parameter())); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index 68c99256a2..79dfd1e409 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -173,9 +173,9 @@ bool HasUniqueFusedUseOfOperandAt( // (2) Is a loop fusion instruction where the only use of 'operand' at 'index' // in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root // at operand 0. Or... -// (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion -// instruction where the only use of 'operand' at 'index' in the set -// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... +// (3) Is a kDot -> kAdd output fusion instruction where the only use of +// 'operand' at 'index' in the set 'user.fused_instructions' is a kAdd fused +// root at operand 0 or 1. Or... 
// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index // 0. // @@ -209,17 +209,13 @@ bool CanShareOperandBufferWithUser( user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. - // Check if one operand of kAdd fused root is either kDot, or nested - // kFusion of kind kTransposeDot. + // Check if one operand of kAdd fused root is kDot or kConvolution. auto* add = user->fused_expression_root(); auto add_operand_it = std::find_if(add->operands().begin(), add->operands().end(), [&](HloInstruction* operand) { return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot || - (operand->opcode() == HloOpcode::kFusion && - operand->fusion_kind() == - HloInstruction::FusionKind::kTransposeDot); + operand->opcode() == HloOpcode::kDot; }); if (add_operand_it == add->operands().end()) { return false; @@ -314,17 +310,13 @@ bool CanShareOperandBufferWithUser(HloInstruction* operand, user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. - // Check if one operand of kAdd fused root is either kDot, or nested - // kFusion of kind kTransposeDot. + // Check if one operand of kAdd fused root is kDot, or kConvolution. auto* add = user->fused_expression_root(); auto add_operand_it = std::find_if(add->operands().begin(), add->operands().end(), [&](HloInstruction* operand) { return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot || - (operand->opcode() == HloOpcode::kFusion && - operand->fusion_kind() == - HloInstruction::FusionKind::kTransposeDot); + operand->opcode() == HloOpcode::kDot; }); if (add_operand_it == add->operands().end()) { return false; diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index f8b309488e..c01b52df62 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -303,48 +303,6 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { *dataflow_analysis_)); } -TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}}))); - auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}}))); - auto b_t = builder.AddInstruction( - HloInstruction::CreateTranspose(data_shape, b, {1, 0})); - - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); - auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape, HloOpcode::kAdd, dot, add_operand)); - - BuildModule(builder.Build()); - - auto nested_fusion = computation_->CreateFusionInstruction( - {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); - - auto fusion = computation_->CreateFusionInstruction( - {add, nested_fusion}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused transpose-dot-add should be share buffer 
with 'add_operand'. - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *dataflow_analysis_)); -} - TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 3312a88844..7323abeb20 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -333,18 +333,7 @@ llvm::Value* IrArray::EmitArrayElementAddress( } CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); - std::vector<llvm::Value*> actual_index; - bool is_implicit_broadcast = false; - // We perform broadcasting when the operand shape has dimension(s) of size - // 1. In this case we fix the index value for that dimension to zero. This - // effectively broadcasts along this dimension. - for (int64 i = 0; i < index.size(); ++i) { - auto dim = shape_->dimensions(i); - actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); - is_implicit_broadcast |= dim == 1; - } - - if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) { + if (index.LinearValidOnShape(*shape_)) { llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); return ir_builder->CreateInBoundsGEP( @@ -354,6 +343,15 @@ llvm::Value* IrArray::EmitArrayElementAddress( {index.linear()}, llvm_ir::AsStringRef(name)); } + std::vector<llvm::Value*> actual_index; + for (int64 i = 0; i < index.size(); ++i) { + // When dimension i is of size 1, LLVM optimization is able to replace + // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to + // produce better code in some cases. + auto dim = shape_->dimensions(i); + actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); + } + // "base_ptr_" has the type of "<ir_type_for_its_shape>*" // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element // should be computed by diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index f74bcb0b79..3a6a7c25f4 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -53,7 +53,7 @@ NameUniquer::NameUniquer(const string& separator) { } string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { - string root = GetSanitizedName(prefix.empty() ? "name" : prefix.ToString()); + string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix)); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 586f6ef7a9..d3bc47e61e 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -702,6 +702,30 @@ class HloInstructionPatternOperandImpl { HloInstructionPattern<OperandType, OperandImpl> operand_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// is a fusion node with a particular kind. 
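+//
+// Exercised through HloInstructionPattern::WithFusionKind below, e.g.:
+//   Match(instr, match::Op().WithFusionKind(HloInstruction::FusionKind::kLoop))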
+template <typename Previous> +class HloInstructionPatternFusionKindImpl { + public: + explicit constexpr HloInstructionPatternFusionKindImpl( + const Previous& previous, ::xla::HloInstruction::FusionKind kind) + : previous_(previous), kind_(kind) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == kind_; + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == kind_; + } + + private: + Previous previous_; + ::xla::HloInstruction::FusionKind kind_; +}; + // A pattern that matches HloInstructions. template <typename HloInstructionType, typename Impl> class HloInstructionPattern { @@ -807,6 +831,16 @@ class HloInstructionPattern { matched_inst_); } + // Modifies the pattern to match only if the instruction is a fusion node with + // the given kind. + constexpr HloInstructionPattern<HloInstructionType, + HloInstructionPatternFusionKindImpl<Impl>> + WithFusionKind(HloInstruction::FusionKind kind) const { + return HloInstructionPattern<HloInstructionType, + HloInstructionPatternFusionKindImpl<Impl>>( + HloInstructionPatternFusionKindImpl<Impl>(impl_, kind), matched_inst_); + } + private: Impl impl_; HloInstructionType** matched_inst_; diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index c88157c312..204e8c9920 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -170,5 +170,28 @@ TEST(PatternMatcherTest, TupleShape) { Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape()))); } +TEST(PatternMatcherTest, FusionKind) { + constexpr char kModuleStr[] = R"( + HloModule test_module + + fused_computation { + ROOT fp0 = f32[] parameter(0) + } + + ENTRY while.v11 { + p0 = f32[] parameter(0) + ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_TRUE(Match( + root, match::Op().WithFusionKind(HloInstruction::FusionKind::kLoop))); + EXPECT_FALSE(Match( + root, match::Op().WithFusionKind(HloInstruction::FusionKind::kInput))); + EXPECT_FALSE(Match(root->operand(0), match::Op().WithFusionKind( + HloInstruction::FusionKind::kLoop))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 48b2922e77..c493547d9e 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -172,11 +172,11 @@ tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape, tensorflow::StringPiece op_type) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("Expected non-tuple argument for %s, but got %s.", - op_type.ToString().c_str(), + std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); } else if (ShapeUtil::IsOpaque(shape)) { return InvalidArgument("Expected non-opaque argument for %s, but got %s.", - op_type.ToString().c_str(), + std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); } else { return tensorflow::Status::OK(); diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 3efd38ce0d..f7a5512fec 100644 --- 
a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -35,7 +35,8 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( const HloInstruction& dot, const TransposeFolding::TransposableGemmOperandsFn& transposable_gemm_operands) { - if (HloOpcode::kDot != dot.opcode()) { + if (HloOpcode::kDot != dot.opcode() || + dot.dot_dimension_numbers().lhs_batch_dimensions_size() != 0) { return {}; } @@ -44,6 +45,8 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( auto& operand = *dot.operand(i); if (operand.IsRank2Transpose()) { operand_set.push_back(i); + } else if (ShapeUtil::Rank(operand.shape()) != 2) { + return {}; } } @@ -74,23 +77,39 @@ using InstructionOperandsPair = // Folds the operands of `dot` that are foldable transposes. `computation` is // the parent HLO computation of `dot`. -// -// Returns whether the module is changed. -bool FoldTransposeIntoDot(InstructionOperandsPair pair) { - auto* dot = pair.first; - std::vector<HloInstruction*> instructions_to_fuse(1, dot); - for (const int64 operand_index : pair.second) { - instructions_to_fuse.push_back(dot->mutable_operand(operand_index)); - } - - // Early-exit if no operands are foldable. - if (instructions_to_fuse.size() == 1) { - return false; +Status FoldTransposeIntoDot(InstructionOperandsPair pair) { + HloInstruction* dot = pair.first; + + DotDimensionNumbers new_dim_numbers = dot->dot_dimension_numbers(); + HloInstruction* new_lhs = dot->mutable_operand(0); + HloInstruction* new_rhs = dot->mutable_operand(1); + + CHECK_EQ(new_dim_numbers.lhs_batch_dimensions_size(), 0); + CHECK_EQ(new_dim_numbers.rhs_batch_dimensions_size(), 0); + CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1); + + for (int64 operand_index : pair.second) { + // We've checked that there aren't any batch dimensions and that the inputs + // are rank 2, and shape inference guarantees that there is exactly one + // contracting dimension. + if (operand_index == 0) { + CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose); + new_dim_numbers.set_lhs_contracting_dimensions( + 0, 1 - new_dim_numbers.lhs_contracting_dimensions(0)); + new_lhs = new_lhs->mutable_operand(0); + } else { + CHECK_EQ(operand_index, 1); + CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose); + new_dim_numbers.set_rhs_contracting_dimensions( + 0, 1 - new_dim_numbers.rhs_contracting_dimensions(0)); + new_rhs = new_rhs->mutable_operand(0); + } } - dot->parent()->CreateFusionInstruction( - instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot); - return true; + std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot( + dot->shape(), new_lhs, new_rhs, new_dim_numbers); + return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } // Folds the operands of `convolution` that are foldable transposes. 
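Note that for a rank-2 dot with a single contracting dimension on each side, folding a transposed operand reduces to flipping that side's contracting dimension, which is what the 1 - contracting_dimension arithmetic above computes. Illustratively, in the HLO syntax of the tests below, dot(x, transpose(y)) with lhs_contracting_dims={1}, rhs_contracting_dims={0} rewrites to dot(x, y) with lhs_contracting_dims={1}, rhs_contracting_dims={1}.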
@@ -205,7 +224,8 @@ StatusOr<bool> TransposeFolding::Run(HloModule* module) { bool changed = false; for (InstructionOperandsPair& pair : foldable_dots) { - changed |= FoldTransposeIntoDot(pair); + TF_RETURN_IF_ERROR(FoldTransposeIntoDot(pair)); + changed = true; } for (InstructionOperandsPair& pair : foldable_convolutions) { changed |= FoldTransposeIntoConvolution(pair); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 0319109f7f..f73f1227aa 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -31,9 +32,12 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -54,83 +58,102 @@ class TransposeFoldingTest : public HloTestBase { }; TEST_F(TransposeFoldingTest, FoldDotTranspose) { - auto builder = HloComputation::Builder("entry_computation"); - HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"x")); - HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"y")); - HloInstruction* transpose_y = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, - /*rhs=*/transpose_y, dot_dnums)); + string hlo_string = R"( +HloModule FoldDotTranspose + +ENTRY entry_computation { + x = f32[2,3]{1,0} parameter(0) + y = f32[2,3]{1,0} parameter(1) + transpose = f32[3,2]{1,0} transpose(y), dimensions={1,0} + ROOT dot = f32[2,2]{1,0} dot(x, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + tools::Parse(hlo_string)); - auto module = CreateNewModule("test_module"); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build(dot)); FoldTranspose(module.get()); - // Instructions after folding: x, y, and the fusion. 
- std::unordered_set<HloInstruction*> instruction_set( - entry_computation->instructions().begin(), - entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.size()) - << "entry_computation should contain exactly 3 instructions."; - HloInstruction* fusion = *instruction_set.begin(); - EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); +} + +TEST_F(TransposeFoldingTest, DontFoldTransposeOfBatchDim) { + string hlo_string = R"( +HloModule FoldDotTranspose - // The fusion instruction should contain two parameters, one transpose and - // one dot. - EXPECT_EQ(4, fusion->fused_instruction_count()); +ENTRY entry_computation { + x = f32[2,3] parameter(0) + y = f32[3,2] parameter(1) + transpose = f32[2,3] transpose(y), dimensions={1,0} + ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + tools::Parse(hlo_string)); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + [](const HloInstruction& convolution, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(TransposeFoldingTest, DontFoldTransposeOfRank1Dot) { + string hlo_string = R"( +HloModule FoldDotTranspose + +ENTRY entry_computation { + x = f32[3] parameter(0) + y = f32[3,2] parameter(1) + transpose = f32[2,3] transpose(y), dimensions={1,0} + ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={}, rhs_batch_dims={0}, lhs_contracting_dims={0}, rhs_contracting_dims={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + tools::Parse(hlo_string)); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + [](const HloInstruction& convolution, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + EXPECT_FALSE(changed); } TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { - auto builder = HloComputation::Builder("entry_computation"); - // 2x1 - HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2<float>({{1}, {2}}))); - // 3x2 - HloInstruction* const1 = - builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2<float>({{1, 2}, {3, 4}, {5, 6}}))); - HloInstruction* transpose0 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0})); - HloInstruction* transpose1 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( - ShapeUtil::MakeShape(F32, {1, 3}), - 
/*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums)); + string hlo_string = R"( +HloModule FoldDotTransposeConstant + +ENTRY entry_computation { + constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } }) + transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0} + constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } }) + transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0} + ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + tools::Parse(hlo_string)); - auto module = CreateNewModule("test_module"); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build(dot)); FoldTranspose(module.get()); - for (auto* instruction : entry_computation->instructions()) { - if (instruction->opcode() == HloOpcode::kFusion) { - CHECK_EQ(2, instruction->operand_count()); - EXPECT_EQ(const0, instruction->operand(0)); - EXPECT_EQ(const1, instruction->operand(1)); - } - } - - // The created fusion instruction should contain two parameters, two - // transposes (one for each parameter) and one dot. - EXPECT_EQ(5, - entry_computation->root_instruction()->fused_instruction_count()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Constant(), op::Constant(), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/1)); } TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { @@ -164,50 +187,32 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { EXPECT_EQ(6, callee_computation->instruction_count()); } -TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { - auto builder = HloComputation::Builder("entry_computation"); - HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"x")); - HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"y")); - HloInstruction* transpose_y = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, - /*rhs=*/transpose_y, dot_dnums)); - - auto module = CreateNewModule("test_module"); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build(dot)); +TEST_F(TransposeFoldingTest, FoldDotTransposeInCall) { + string hlo_string = R"( +HloModule FoldDotTransposeInCall - HloInstruction* call = module->OutlineExpressionFromComputation( - {transpose_y, dot}, "outlined", entry_computation); +callee { + name.0 = f32[2,3]{1,0} parameter(0) + name.1 = f32[2,3]{1,0} parameter(1) + transpose.clone = f32[3,2]{1,0} transpose(name.0), dimensions={1,0} + ROOT dot.clone = f32[2,2]{1,0} dot(name.1, transpose.clone), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +ENTRY entry_computation { + y = f32[2,3]{1,0} parameter(1) + x = f32[2,3]{1,0} parameter(0) + ROOT call = f32[2,2]{1,0} call(y, x), to_apply=callee +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + tools::Parse(hlo_string)); FoldTranspose(module.get()); - // Instructions after folding: x, y, and the fusion. 
- std::unordered_set<HloInstruction*> instruction_set( - entry_computation->instructions().begin(), - entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(call)) - << "call is not in entry_computation."; - CHECK(instruction_set.empty()) - << "entry_computation should contain exactly 3 instructions."; - HloInstruction* fusion = - call->called_computations().front()->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); - - // The fusion instruction should contain two parameters, one transpose and - // one dot. - EXPECT_EQ(4, fusion->fused_instruction_count()); + const HloComputation* callee = module->GetComputationWithName("callee"); + ASSERT_NE(callee, nullptr); + EXPECT_THAT(callee->root_instruction(), + op::Dot(op::Parameter(1), op::Parameter(0), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); } // Test that a two dimension swap of the kernel gets folded into convolution. diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 5b44c26b7c..4f64fe8f83 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ #include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" namespace xla { diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index cccbce5fc8..0e1387c939 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -13,13 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// StatusOr<T> is the union of a Status object and a T -// object. StatusOr models the concept of an object that is either a -// usable value, or an error Status explaining why such a value is -// not present. To this end, StatusOr<T> does not allow its Status -// value to be Status::OK. Furthermore, the value of a StatusOr<T*> -// must not be null. This is enforced by a debug check in most cases, -// but even when it is not, clients must not set the value to null. +// StatusOr<T> is the union of a Status object and a T object. StatusOr models +// the concept of an object that is either a value, or an error Status +// explaining why such a value is not present. To this end, StatusOr<T> does not +// allow its Status value to be Status::OK. // // The primary use-case for StatusOr<T> is as the return value of a // function which may fail. diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index f9d25945bc..7d76370e85 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -75,6 +75,14 @@ TEST(StatusOr, ElementType) { static_assert(std::is_same<StatusOr<char>::element_type, char>(), ""); } +TEST(StatusOr, NullPointerStatusOr) { + // As a very special case, null-plain-pointer StatusOr used to be an + // error. Test that it no longer is. 
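+ // With this change, a null StatusOr<T*> behaves like any other value: ok()
+ // returns true and ValueOrDie() yields the null pointer.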
+ StatusOr<int*> null_status(nullptr); + EXPECT_TRUE(null_status.ok()); + EXPECT_EQ(null_status.ValueOrDie(), nullptr); +} + TEST(StatusOr, TestNoDefaultConstructorInitialization) { // Explicitly initialize it with an error code. StatusOr<NoDefaultConstructor> statusor(tensorflow::errors::Cancelled("")); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 0571ff5055..b982cf0dbc 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1867,7 +1867,10 @@ xla_test( xla_test( name = "local_client_execute_test", + # TODO(b/79375911): Test times out in LLVM at normal size. + size = "large", srcs = ["local_client_execute_test.cc"], + shard_count = 30, tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:literal_util", diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index c09e7eaf2b..41f9a5f666 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -565,4 +565,33 @@ XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); } +std::unique_ptr<GlobalData> +ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, + const Literal& literal, + const string& name, + XlaBuilder* builder, + XlaOp* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +std::unique_ptr<GlobalData> +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr<Literal> converted_literal; + if (use_bfloat16_) { + converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr<GlobalData> data = + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index e58979a303..16e838e60f 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -616,35 +616,6 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2( return result; } -std::unique_ptr<GlobalData> -ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, - const Literal& literal, - const string& name, - XlaBuilder* builder, - XlaOp* data_handle) { - return CreateParameterAndTransferLiteral(parameter_number, literal, name, - nullptr, builder, data_handle); -} - -std::unique_ptr<GlobalData> -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, XlaBuilder* builder, - XlaOp* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr<Literal> converted_literal; - if (use_bfloat16_) { - converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } - std::unique_ptr<GlobalData> data = - 
client_->TransferToServer(*param_literal, device_handle) - .ConsumeValueOrDie(); - *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); - return data; -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 6b3efba4f8..efa5aed2d1 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -798,5 +798,250 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, this->error_spec_); } +TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { + std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array( + new Array2D<float>({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D<float> expected({{96.0, 105.0, 114.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { + std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array( + new Array2D<float>({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D<float> expected({{105.0}, {105.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. 
+TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstRHSReverseMM)))) { + std::unique_ptr<Array2D<float>> constant_lhs_array( + new Array2D<float>({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D<float> expected({{105.0, 105.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstLHSReverseMM)))) { + std::unique_ptr<Array2D<float>> constant_lhs_array( + new Array2D<float>({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array(new Array2D<float>( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D<float> expected({{96.0}, {105.0}, {114.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. 
+TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) { + std::unique_ptr<Array2D<float>> constant_lhs_array( + new Array2D<float>({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array( + new Array2D<float>({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D<float> expected({{126.0, 129.0, 132.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) { + std::unique_ptr<Array2D<float>> constant_lhs_array( + new Array2D<float>({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array( + new Array2D<float>({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D<float> expected({{129.0}, {129.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. 
+TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) { + std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array( + new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D<float> expected({{56.0, 168.0, 91.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) { + std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr<Array2D<float>> constant_rhs_array( + new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1<int32>({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D<float> expected({{168.0}, {168.0}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 3a945fb3b1..156a06c596 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -30,6 +30,7 @@ namespace { using tensorflow::StringPiece; using tensorflow::gtl::optional; +using tensorflow::str_util::Join; using tensorflow::str_util::Split; using tensorflow::str_util::SplitAndParseAsInts; using tensorflow::strings::Printf; @@ -53,7 +54,7 @@ class HloParser { std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); } // Returns the error information. - string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + string GetError() const { return Join(error_, "\n"); } private: // ParseXXX returns false if an error occurred. @@ -245,7 +246,7 @@ bool HloParser::Error(LocTy loc, StringPiece msg) { error_lines.push_back(std::string(lexer_.GetLine(loc))); error_lines.push_back(col == 0 ? 
"" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + error_.push_back(Join(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } @@ -439,6 +440,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional<OpMetadata> metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; + optional<string> backend_config; + attrs["backend_config"] = {/*required=*/false, AttrTy::kString, + &backend_config}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -1093,8 +1098,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_name(name); - // Add common attrs (sharding, control predecessors) to the instruction, if - // they were seen. + // Add shared attributes like metadata to the instruction, if they were seen. if (sharding) { instruction->set_sharding( HloSharding::FromProto(sharding.value()).ValueOrDie()); @@ -1111,6 +1115,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (metadata) { instruction->set_metadata(*metadata); } + if (backend_config) { + instruction->set_backend_config(std::move(*backend_config)); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1488,11 +1495,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - tensorflow::str_util::Join( - elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { - tensorflow::strings::StrAppend(out, num_elems - 1); - }), + Join(elems_seen_until_dim, ",", + [](string* out, const int64& num_elems) { + tensorflow::strings::StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1680,7 +1686,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", tensorflow::str_util::Join(index, ", "), "]")); + ": [", Join(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -1848,7 +1854,19 @@ bool HloParser::ParseAttributeHelper( } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { - return Error(loc, Printf("unexpected attribute %s", name.c_str())); + string allowed_attrs; + if (attrs.empty()) { + allowed_attrs = "No attributes are allowed here."; + } else { + allowed_attrs = StrCat( + "Allowed attributes: ", + Join(attrs, ", ", + [&](string* out, const std::pair<string, AttrConfig>& kv) { + StrAppend(out, kv.first); + })); + } + return Error(loc, Printf("unexpected attribute \"%s\". 
%s", name.c_str(), + allowed_attrs.c_str())); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 4e085bc89c..e100d8cda1 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -65,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { R"(HloModule constant_pred_module ENTRY %constant_pred () -> pred[] { - ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar" } )" @@ -81,13 +81,14 @@ ENTRY %constant_s32 () -> s32[] { )" }, -// f32 constant, but the value is not a decimal +// f32 constant, but the value is not a decimal and there is a backend +// configuration { "ConstantF32", R"(HloModule ConstantF32_module ENTRY %ConstantF32.v4 () -> f32[] { - ROOT %constant = f32[] constant(42) + ROOT %constant = f32[] constant(42), backend_config="this is a configuration" } )" @@ -1013,6 +1014,19 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { // but the constant names will not be exactly the same. } +TEST_F(HloParserTest, ConfigurationField) { + const string original = R"(HloModule AModule +ENTRY %configuration_test() -> s32[] { + %constant = s32[] constant(42), backend_config="foo bar" +})"; + auto result = Parse(original); + TF_ASSERT_OK(result.status()); + EXPECT_EQ("foo bar", result.ValueOrDie() + ->entry_computation() + ->root_instruction() + ->backend_config()); +} + TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { const string original = R"(HloModule some_2_module @@ -1092,7 +1106,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} } )"; @@ -1138,7 +1152,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { )"; ExpectHasSubstr(Parse(original).status().error_message(), - "unexpected attribute calls"); + "unexpected attribute \"calls\""); } TEST_F(HloParserTest, MissingAttribute) { diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD index 83f3bafc42..8064a967cd 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -19,6 +19,7 @@ py_library( srcs = [ "activity.py", "annos.py", + "cfg.py", "live_values.py", "type_info.py", ], @@ -44,6 +45,19 @@ py_test( ) py_test( + name = "cfg_test", + srcs = ["cfg_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":static_analysis", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + +py_test( name = "live_values_test", srcs = 
["live_values_test.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py new file mode 100644 index 0000000000..230e4cc0f3 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py @@ -0,0 +1,431 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow graph analysis. + +Given a Python AST we construct a control flow graph, with edges both to the +next and previous statements (so it can easily walk the graph both ways). Its +nodes contain the AST of the statements. It can then perform forward or backward +analysis on this CFG. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import functools +import operator + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct.static_analysis import activity + + +class CfgNode(object): + """A node in the CFG.""" + __slots__ = ['next', 'value', 'prev'] + + def __init__(self, value): + self.next = set() + self.prev = set() + self.value = value + + +class Cfg(namedtuple('Cfg', ['entry', 'exit'])): + """A Control Flow Graph. + + Each statement is represented as a node. For control flow statements such + as conditionals and loops the conditional itself is a node which either + branches or cycles, respectively. + Attributes: + entry: The entry node, which contains the `gast.arguments` node of the + function definition. + exit: The exit node. This node is special because it has no value (i.e. no + corresponding AST node). This is because Python functions can have + multiple return statements. + """ + pass + + +class CfgBuilder(gast.NodeVisitor): + """Construct a control flow graph. + + Construct a CFG starting from a FunctionDef node. + Usage: + cfg_obj = CfgBuilder().build_cfg(fndef_node) + """ + + def __init__(self): + # The current leaves of the CFG + self.current_leaves = [] + # TODO(alexbw): generalize to break, return, continue, yield, etc. + # A stack of lists, tracking continue statements + self.continue_ = [] + # A stack of lists tracking break nodes + self.break_ = [] + + def set_current_leaves(self, cfg_node): + """Link this cfg_node to the current leaves. + + This is the central function for building the CFG. It links the current + head cfg_nodes to the passed cfg_node. It then resets the head to the + passed cfg_node. + + Args: + cfg_node: A CfgNode instance. + """ + for head in self.current_leaves: + head.next.add(cfg_node) + # While we're linking the CFG forward, add backlinks + cfg_node.prev.add(head) + self.current_leaves = [cfg_node] + + def build_cfg(self, node): + """Build a CFG for a function. + + Implementation of building a CFG for dataflow analysis. 
See, e.g.: + https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf + + Args: + node: A function definition the body of which to analyze. + Returns: + A CFG object. + Raises: + TypeError: If the input is not a function definition. + """ + if not isinstance(node, gast.FunctionDef): + raise TypeError('input must be a function definition') + entry_cfg_node = CfgNode(node.args) + self.current_leaves = [entry_cfg_node] + self.visit_statements(node.body) + exit_cfg_node = CfgNode(None) + self.set_current_leaves(exit_cfg_node) + return Cfg(entry_cfg_node, exit_cfg_node) + + def visit_statements(self, nodes): + for node in nodes: + # Check for control flow + if isinstance(node, (gast.For, gast.While, gast.If, gast.Try, gast.Break, + gast.Continue, gast.With)): + self.visit(node) + else: + expr = CfgNode(node) + self.set_current_leaves(expr) + + def generic_visit(self, node): + raise ValueError('unknown control flow') + + def visit_If(self, node): + # TODO(alexbw): change this to use immutable tuples instead of lists + # The current head will hold the conditional + test = CfgNode(node.test) + self.set_current_leaves(test) + # Handle the body + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves = [] + self.current_leaves.append(test) + # Handle the orelse + self.visit_statements(node.orelse) + self.current_leaves.extend(body_exit) + + def visit_While(self, node): + test = CfgNode(node.test) + self.set_current_leaves(test) + # Start a new level of nesting + self.break_.append([]) + self.continue_.append([]) + # Handle the body + self.visit_statements(node.body) + self.current_leaves.extend(self.continue_.pop()) + self.set_current_leaves(test) + # Handle the orelse + self.visit_statements(node.orelse) + # The break statements and the test go to the next node + self.current_leaves.extend(self.break_.pop()) + + def visit_For(self, node): + iter_ = CfgNode(node.iter) + self.set_current_leaves(iter_) + self.break_.append([]) + self.continue_.append([]) + self.visit_statements(node.body) + self.current_leaves.extend(self.continue_.pop()) + self.set_current_leaves(iter_) + self.current_leaves.extend(self.break_.pop()) + + def visit_Break(self, node): + self.break_[-1].extend(self.current_leaves) + self.current_leaves[:] = [] + + def visit_Continue(self, node): + self.continue_[-1].extend(self.current_leaves) + self.current_leaves[:] = [] + + def visit_Try(self, node): + self.visit_statements(node.body) + body = self.current_leaves + handlers = [] + for handler in node.handlers: + self.current_leaves = body[:] + self.visit_statements(handler.body) + handlers.extend(self.current_leaves) + self.current_leaves = body + self.visit_statements(node.orelse) + self.current_leaves = handlers + self.current_leaves + self.visit_statements(node.finalbody) + + def visit_With(self, node): + for item in node.items: + self.set_current_leaves(CfgNode(item)) + self.visit_statements(node.body) + + +# TODO(alexbw): once CFG analysis occurs at a block level, +# this extra class will not be necessary +class PropagateAnalysis(gast.NodeVisitor): + """Port analysis annotations from statements to their enclosing blocks.""" + + def __init__(self, analysis): + self.transfer_fn = analysis.transfer_fn + self.in_label = analysis.in_label + self.out_label = analysis.out_label + super(PropagateAnalysis, self).__init__() + + def visit_If(self, node): + # Depth-first. 
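+ # Children are visited before their parent so that their in/out
+ # annotations exist by the time they are merged into this node's.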
+    self.generic_visit(node)
+    incoming = anno.getanno(node.body[0], self.in_label)
+    incoming |= anno.getanno(node.test, self.in_label)
+    outgoing = anno.getanno(node.body[-1], self.out_label)
+    outgoing |= anno.getanno(node.test, self.out_label)
+    if node.orelse:
+      orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label)
+      outgoing = self.transfer_fn(outgoing, orelse_outgoing)
+    anno.setanno(node, self.in_label, incoming)
+    anno.setanno(node, self.out_label, outgoing)
+
+  def visit_For(self, node):
+    self.generic_visit(node)
+    incoming = set(anno.getanno(node.body[0], self.in_label))
+    incoming -= set((anno.getanno(node.target, anno.Basic.QN),))
+    outgoing = anno.getanno(node.body[-1], self.out_label)
+    if node.orelse:
+      orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label)
+      outgoing = self.transfer_fn(outgoing, orelse_outgoing)
+    anno.setanno(node, self.in_label, frozenset(incoming))
+    anno.setanno(node, self.out_label, outgoing)
+
+  def visit_While(self, node):
+    self.generic_visit(node)
+    incoming = anno.getanno(node.body[0], self.in_label)
+    incoming |= anno.getanno(node.test, self.in_label)
+    outgoing = anno.getanno(node.body[-1], self.out_label)
+    if node.orelse:
+      orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label)
+      outgoing = self.transfer_fn(outgoing, orelse_outgoing)
+    anno.setanno(node, self.in_label, incoming)
+    anno.setanno(node, self.out_label, outgoing)
+
+  def visit_With(self, node):
+    self.generic_visit(node)
+    incoming = anno.getanno(node.body[0], self.in_label)
+    for item in node.items:
+      incoming |= anno.getanno(item, self.in_label)
+    outgoing = anno.getanno(node.body[-1], self.out_label)
+    anno.setanno(node, self.in_label, incoming)
+    anno.setanno(node, self.out_label, outgoing)
+
+
+# TODO(alexbw): Abstract the CFG walking machinery into a superclass
+# which is parameterized on which fields it selects when walking.
+# TODO(alexbw): Abstract the application of dataflow analysis
+class Forward(object):
+  """Forward analysis on CFG.
+
+  Args:
+    label: A name for this analysis, e.g. 'active' for activity analysis. The
+      AST nodes in the CFG will be given annotations '<label>_in',
+      '<label>_out', '<label>_gen' and '<label>_kill', which contain the
+      incoming values, outgoing values, values generated by the statement, and
+      values deleted by the statement, respectively.
+    transfer_fn: Either the AND or OR operator. If the AND operator is used,
+      this becomes a forward must analysis (i.e. a value will only be carried
+      forward if it appears on all incoming paths). The OR operator means that
+      forward may analysis is done (i.e. the union of incoming values will be
+      taken).
+  """
+
+  def __init__(self, label, context, transfer_fn=operator.or_):
+    self.transfer_fn = transfer_fn
+    self.context = context
+    self.out_label = label + '_out'
+    self.in_label = label + '_in'
+    self.gen_label = label + '_gen'
+    self.kill_label = label + '_kill'
+
+  # TODO(alexbw): see if we can simplify by visiting breadth-first
+  def visit(self, node):
+    """Depth-first walking of the CFG, applying dataflow information propagation."""
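+    # The update below is the standard forward dataflow step:
+    #   out(n) = gen(n) | (in(n) - kill(n))
+    # where in(n) folds the predecessors' out-sets through transfer_fn.
+    # E.g. with incoming {x, y}, kill {y} and gen {z}, out becomes {x, z}.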
+    # node.value is None only for the exit CfgNode.
+    if not node.value:
+      return
+
+    if anno.hasanno(node.value, self.out_label):
+      before = hash(anno.getanno(node.value, self.out_label))
+    else:
+      before = None
+    preds = [
+        anno.getanno(pred.value, self.out_label)
+        for pred in node.prev
+        if anno.hasanno(pred.value, self.out_label)
+    ]
+    if preds:
+      incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0])
+    else:
+      incoming = frozenset()
+    anno.setanno(node.value, self.in_label, incoming)
+    gen, kill = self.get_gen_kill(node, incoming)
+    anno.setanno(node.value, self.gen_label, gen)
+    anno.setanno(node.value, self.kill_label, kill)
+    anno.setanno(node.value, self.out_label, (incoming - kill) | gen)
+
+    if hash(anno.getanno(node.value, self.out_label)) != before:
+      for succ in node.next:
+        self.visit(succ)
+
+  def get_gen_kill(self, cfg_node, incoming):
+    """Calculate Gen and Kill properties of a CFG node in dataflow analysis.
+
+    Takes the CFG node as well as the set of incoming values, and must return
+    the set of values newly generated by the statement as well as the set of
+    deleted (killed) values.
+
+    Args:
+      cfg_node: A CfgNode instance.
+      incoming: The set of values flowing into `cfg_node`.
+    """
+    raise NotImplementedError()
+
+
+class Backward(Forward):
+  """Backward analysis on CFG."""
+
+  def visit(self, cfg_node):
+    # cfg_node.value is None for the exit node, which will be visited only once
+    if not cfg_node.value:
+      for pred in cfg_node.prev:
+        self.visit(pred)
+      return
+
+    if anno.hasanno(cfg_node.value, self.in_label):
+      before = hash(anno.getanno(cfg_node.value, self.in_label))
+    else:
+      before = None
+    succs = [
+        anno.getanno(succ.value, self.in_label)
+        for succ in cfg_node.next
+        if anno.hasanno(succ.value, self.in_label)
+    ]
+    if succs:
+      incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0])
+    else:
+      incoming = frozenset()
+    anno.setanno(cfg_node.value, self.out_label, incoming)
+    gen, kill = self.get_gen_kill(cfg_node, incoming)
+    anno.setanno(cfg_node.value, self.gen_label, gen)
+    anno.setanno(cfg_node.value, self.kill_label, kill)
+    anno.setanno(cfg_node.value, self.in_label, (incoming - kill) | gen)
+    if hash(anno.getanno(cfg_node.value, self.in_label)) != before:
+      for pred in cfg_node.prev:
+        self.visit(pred)
+
+
+def run_analyses(node, analyses):
+  """Perform dataflow analysis on all functions within an AST.
+
+  Args:
+    node: An AST node on which to run dataflow analysis.
+    analyses: Either an instance of the Forward or Backward dataflow analysis
+      class, or a list or tuple of them.
+
+  Returns:
+    node: The node, but now with annotations on the AST nodes containing the
+      results of the dataflow analyses.
+  """
+  if not isinstance(analyses, (tuple, list)):
+    analyses = (analyses,)
+  for analysis in analyses:
+    if not isinstance(analysis, (Forward, Backward)):
+      raise TypeError('not a valid dataflow analysis object')
+
+  for child_node in gast.walk(node):
+    if isinstance(child_node, gast.FunctionDef):
+      cfg_obj = CfgBuilder().build_cfg(child_node)
+      for analysis in analyses:
+        if isinstance(analysis, Backward):
+          analysis.visit(cfg_obj.exit)
+        elif isinstance(analysis, Forward):
+          analysis.visit(cfg_obj.entry)
+  for analysis in analyses:
+    PropagateAnalysis(analysis).visit(node)
+  return node
+
+
+class Liveness(Backward):
+  """Perform a liveness analysis.
+
+  Each statement is annotated with a set of variables that may be used
+  later in the program.
+ """ + + def __init__(self, context): + super(Liveness, self).__init__('live', context) + + def get_gen_kill(self, node, _): + gen = activity.get_read(node.value, self.context) + kill = activity.get_updated(node.value, self.context) + return gen, kill + + +class ReachingDefinitions(Forward): + """Perform reaching definition analysis. + + Each statement is annotated with a set of (variable, definition) pairs. + """ + + def __init__(self, context): + super(ReachingDefinitions, self).__init__('definitions', context) + + def get_gen_kill(self, node, incoming): + definitions = activity.get_updated(node.value, self.context) + gen = frozenset((id_, node.value) for id_ in definitions) + kill = frozenset(def_ for def_ in incoming if def_[0] in definitions) + return gen, kill + + +class Defined(Forward): + """Perform defined variable analysis. + + Each statement is annotated with a set of variables which are guaranteed to + be defined at that point. + """ + + def __init__(self, context): + super(Defined, self).__init__('defined', context, transfer_fn=operator.and_) + + def get_gen_kill(self, node, _): + gen = activity.get_updated(node.value, self.context) + return gen, frozenset() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py new file mode 100644 index 0000000000..af7eaf30e8 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py @@ -0,0 +1,252 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cfg module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import cfg +from tensorflow.python.platform import test + + +class CFGTest(test.TestCase): + + def _parse_and_analyze(self, test_fn, namespace, arg_types=None): + arg_types = arg_types or {} + node, source = parser.parse_entity(test_fn) + ctx = context.EntityContext( + namer=None, + source_code=source, + source_file=None, + namespace=namespace, + arg_values=None, + arg_types=arg_types, + owner_type=None, + recursive=True) + node = qual_names.resolve(node) + return node, ctx + + def _check_anno_matches(self, node, anno_name, var_names): + if isinstance(var_names, str): + var_names = (var_names,) + qual_vars = set() + for var_name in var_names: + if isinstance(var_name, str): + if '[' in var_name or ']' in var_name: + raise ValueError('Annotation matching not supported with subscript.') + if '.' 
not in var_name: + qual_vars.add(qual_names.QN(var_name)) + else: + attrs = var_name.split('.') + this_qn = functools.reduce(qual_names.QN, attrs[1:], + qual_names.QN(attrs[0])) + qual_vars.add(this_qn) + self.assertEqual(anno.getanno(node, anno_name), qual_vars) + + def test_reaching(self): + + def f(x): + print(x) + while True: + x = x + x = x + return x + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) + body = node.body[0].body + # Only the argument reaches the expression + def_in = anno.getanno(body[0], 'definitions_in') + # One element, x, from arguments + self.assertEqual(set(type(d[1]) for d in def_in), set((gast.arguments,))) + + while_body = body[1].body + def_in = anno.getanno(while_body[0], 'definitions_in') + # One definition, two possible sources. + # - One from an assignment (if the loop is entered) + # - The other from the arguments (if loop is not entered) + self.assertEqual( + set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) + + def_in = anno.getanno(while_body[1], 'definitions_in') + # If we've reached this line, the only reaching definition of x is the + # Assign node in previous line + self.assertEqual(set(type(d[1]) for d in def_in), set((gast.Assign,))) + + def_in = anno.getanno(body[2], 'definitions_in') + # Same situation as while_body[0] + self.assertEqual( + set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) + + def test_defined(self): + + def f(x): + if x: + y = 2 # pylint: disable=unused-variable + return x + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + body = node.body[0].body + # only x is for sure defined at the end + self._check_anno_matches(body[1], 'defined_in', 'x') + # at the end of the if body both x and y are defined + if_body = body[0].body + self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y')) + + # TODO(alexbw): b/73926938 split this test up + def test_live(self): + + def get_live_annotated_fnbody(f): + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Liveness(ctx)) + body = node.body[0].body + return body + + def f1(x): + a = g(x) # pylint: disable=undefined-variable + b = h(a) # pylint: disable=undefined-variable, unused-variable + return x + + def f2(x, a): # pylint: disable=unused-argument + if a > 0: # x should not be live + x = 0 + if a > 1: + x = 1 + else: + x = 2 + + def f3(x, a): + if a > 0: # x and a should be live + x = 0 + if a > 1: # x and a should be live_in + x = 1 + return x # x should be live + + def f4(x, a): + if a > 0: # x should be live + x = 0 + x += 1 + + def f5(x, a): + if a > 0: # x.y should be live + x.y = 0 + return x.y + + def f6(x): + return x # should this cause x.* to be live? 
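+      # For reference, liveness is the Backward instance with
+      #   live_in(s) = read(s) | (live_out(s) - updated(s)),
+      # so in f4 above `x` is live into `if a > 0:` because `x += 1` reads it
+      # later, while in f2 nothing reads `x` after the branches assign it.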
+ + def f7(x, n): + for i in range(n): + x += i + return x + + def f8(x, f): + with f: + x += 1 + + body = get_live_annotated_fnbody(f1) + self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x')) + self._check_anno_matches(body[2], 'live_out', ()) + + body = get_live_annotated_fnbody(f2) + self._check_anno_matches(body[0], 'live_in', ('a')) + self._check_anno_matches(body[1], 'live_in', ('a')) + + body = get_live_annotated_fnbody(f3) + self._check_anno_matches(body[0], 'live_in', ('a', 'x')) + self._check_anno_matches(body[1], 'live_in', ('a', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + + body = get_live_annotated_fnbody(f4) + self._check_anno_matches(body[0], 'live_in', ('x', 'a')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + body = get_live_annotated_fnbody(f5) + self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a')) + + body = get_live_annotated_fnbody(f6) + self._check_anno_matches(body[0], 'live_in', ('x')) + + body = get_live_annotated_fnbody(f7) + self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + body = get_live_annotated_fnbody(f8) + self._check_anno_matches(body[0], 'live_in', ('f', 'x')) + + def test_node_equality(self): + node_a = gast.parse('y = x').body[0] + node_b = gast.parse('y = x').body[0] + self.assertNotEqual(node_a, node_b) + + def test_nested_functions_defined(self): + + def f(x): + y = x * 2 + + def g(z): + return z + y + + return g(x) + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + + body = node.body[0].body + self.assertEqual( + anno.getanno(body[2], 'defined_in'), + frozenset(map(qual_names.QN, ('g', 'x', 'y')))) + + # TODO(alexbw): CFG analysis doesn't currently cross FunctionDef boundaries. + # NOTE: 'z' is easy to find, but 'y' is not identified as + # defined, because CFG analysis is applied with each function separately. + # fndef_body = body[1].body + # self.assertEqual( + # anno.getanno(fndef_body[0], 'defined_in'), + # frozenset(map(qual_names.QN, ('z', 'y')))) + + def test_nested_functions_dont_leak_definitions(self): + + def f(x): + print(x) + + def g(): + y = 2 + return y + + return g() # y is not defined here + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + body = node.body[0].body + self.assertEqual( + anno.getanno(body[2], 'defined_in'), + frozenset(map(qual_names.QN, ('x', 'g')))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 9d6cc9245a..f06b73c00d 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -501,11 +501,18 @@ def sparse_make_stats_update( example_partition_ids) # Compute aggregate stats for each partition. + # Since unsorted_segment_sum can be numerically unstable, use 64bit + # operation. 
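+  # The same pattern in isolation (an illustrative sketch using the public
+  # TF API, not part of this change):
+  #   tf.cast(tf.unsorted_segment_sum(tf.cast(v, tf.float64), ids, n),
+  #           tf.float32)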
+ gradients64 = math_ops.cast(gradients, dtypes.float64) + hessians64 = math_ops.cast(hessians, dtypes.float64) per_partition_gradients = math_ops.unsorted_segment_sum( - gradients, mapped_partitions, array_ops.size(unique_partitions)) + gradients64, mapped_partitions, array_ops.size(unique_partitions)) per_partition_hessians = math_ops.unsorted_segment_sum( - hessians, mapped_partitions, array_ops.size(unique_partitions)) - + hessians64, mapped_partitions, array_ops.size(unique_partitions)) + per_partition_gradients = math_ops.cast(per_partition_gradients, + dtypes.float32) + per_partition_hessians = math_ops.cast(per_partition_hessians, + dtypes.float32) # Prepend a bias feature per partition that accumulates the stats for all # examples in that partition. bias_feature_ids = array_ops.fill( diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 1b184d296b..50cc00afdc 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -187,7 +187,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): stamp_token: Expected current token. next_stamp_token: Next value for the token. Returns: - A list of quantiles or approximate boundaries. + The flush operation. """ return gen_quantile_ops.quantile_accumulator_flush( quantile_accumulator_handle=self._quantile_accumulator_handle, diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index d2c30f1215..e529b25b3c 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -19,6 +19,7 @@ For creating and managing dependencies: @@CheckpointableObjectGraph @@dot_graph_from_checkpoint @@object_metadata +@@NoDependency @@split_dependency """ @@ -29,6 +30,7 @@ from __future__ import print_function from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.python.training.checkpointable import NoDependency from tensorflow.python.training.checkpointable_utils import object_metadata from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 6588fd04ac..2568b899d7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -427,7 +427,9 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testMapAndBatchDatasetHelper(self, num_parallel_batches=1): + def _testMapAndBatchDatasetHelper(self, + num_parallel_calls=None, + num_parallel_batches=None): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). 
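For orientation, a typical invocation of the updated transformation might look like the following (a minimal sketch; it assumes the `batching` module path shown in the later hunk of this change):

  import tensorflow as tf
  from tensorflow.contrib.data.python.ops import batching

  dataset = tf.data.Dataset.range(100).apply(
      batching.map_and_batch(
          map_func=lambda x: x * x,
          batch_size=8,
          num_parallel_calls=2))  # mutually exclusive with num_parallel_batches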
@@ -446,6 +448,7 @@ class BatchDatasetTest(test.TestCase): batching.map_and_batch( map_func=_map_fn, batch_size=batch_size, + num_parallel_calls=num_parallel_calls, num_parallel_batches=num_parallel_batches)) .make_initializable_iterator()) init_op = iterator.initializer @@ -497,12 +500,18 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testMapAndBatchDataset(self): + def testMapAndBatch(self): return self._testMapAndBatchDatasetHelper() - def testMapAndBatchDatasetWithParallelBatching(self): + def testMapAndBatchWithParallelBatches(self): return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) + def testMapAndBatchWithSequentialCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=1) + + def testMapAndBatchWithParallelCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=2) + def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): iterator = ( dataset_ops.Dataset.range(10).apply( @@ -682,7 +691,7 @@ class UnbatchDatasetSerializationTest( class MapAndBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): - def testSerializationCore(self): + def testNumParallelBatches(self): range_size = 11 num_repeats = 2 batch_size = 5 @@ -709,6 +718,33 @@ class MapAndBatchDatasetSerializationTest( self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), num_outputs_drop_remainder) + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 42ec2b0b01..b9393de4e9 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -466,14 +466,14 @@ def assert_element_shape(expected_shapes): class _MapAndBatchDataset(dataset_ops.MapDataset): """A `Dataset` that maps a function over a batch of elements.""" - def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches, + def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, drop_remainder): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) self._batch_size_t = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_batches_t = ops.convert_to_tensor( - num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + self._num_parallel_calls_t = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") 
self._drop_remainder_t = ops.convert_to_tensor( drop_remainder, dtype=dtypes.bool, name="drop_remainder") @@ -483,12 +483,12 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def _as_variant_tensor(self): # pylint: disable=protected-access input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.map_and_batch_dataset( + return gen_dataset_ops.map_and_batch_dataset_v2( input_resource, self._map_func.captured_inputs, f=self._map_func, batch_size=self._batch_size_t, - num_parallel_batches=self._num_parallel_batches_t, + num_parallel_calls=self._num_parallel_calls_t, drop_remainder=self._drop_remainder_t, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), @@ -511,8 +511,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def map_and_batch(map_func, batch_size, - num_parallel_batches=1, - drop_remainder=False): + num_parallel_batches=None, + drop_remainder=False, + num_parallel_calls=None): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset @@ -528,21 +529,37 @@ def map_and_batch(map_func, nested structure of tensors. batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. - num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the - number of batches to create in parallel. On one hand, higher values can - help mitigate the effect of stragglers. On the other hand, higher values - can increase contention if CPU is scarce. - drop_remainder: A `tf.bool` scalar `tf.Tensor`, representing whether the - last batch should be dropped in case its size is smaller than desired; - the default behavior is not to drop the smaller batch. + num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`, + representing the number of batches to create in parallel. On one hand, + higher values can help mitigate the effect of stragglers. On the other + hand, higher values can increase contention if CPU is scarce. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether the last batch should be dropped in case its size is smaller than + desired; the default behavior is not to drop the smaller batch. + num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, + representing the number of elements to process in parallel. If not + specified, `batch_size * num_parallel_batches` elements will be + processed in parallel. Returns: A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. + + Raises: + ValueError: If both `num_parallel_batches` and `num_parallel_calls` are + specified. 
""" + if num_parallel_batches is None and num_parallel_calls is None: + num_parallel_calls = batch_size + elif num_parallel_batches is not None and num_parallel_calls is None: + num_parallel_calls = batch_size * num_parallel_batches + elif num_parallel_batches is not None and num_parallel_calls is not None: + raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " + "arguments are mutually exclusive.") + def _apply_fn(dataset): return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_batches, drop_remainder) + num_parallel_calls, drop_remainder) return _apply_fn diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 946310aa6f..45d191127e 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -265,6 +265,10 @@ class NamedDistribution(object): one_device_strategy = NamedDistribution( "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"), None) +tpu_strategy_single_iteration = NamedDistribution( + "TPUSingleIteration", + tpu_strategy.TPUStrategy(iterations_per_step=1), + required_tpu=True) tpu_strategy = NamedDistribution( "TPU", tpu_strategy.TPUStrategy(), required_tpu=True) mirrored_strategy_with_gpu_and_cpu = NamedDistribution( diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py index b87224251c..2b05884b9b 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""An example tf.keras model that is trained using MirroredStrategy.""" +"""An example of training tf.keras Model using MirroredStrategy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from sys import argv + +import sys + import numpy as np import tensorflow as tf @@ -33,30 +35,37 @@ def input_fn(): def main(args): if len(args) < 2: - print('You must specify model_dir for checkpoints such as' - ' /tmp/tfkeras_example./') + print('You must specify model_dir for checkpoints such as' + ' /tmp/tfkeras_example/.') return - print('Using %s to store checkpoints.' % args[1]) - - strategy = tf.contrib.distribute.MirroredStrategy( - ['/device:GPU:0', '/device:GPU:1']) - config = tf.estimator.RunConfig(train_distribute=strategy) - optimizer = tf.train.GradientDescentOptimizer(0.2) + model_dir = args[1] + print('Using %s to store checkpoints.' % model_dir) + # Define tf.keras Model. model = tf.keras.Sequential() model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,))) model.add(tf.keras.layers.Dense(1, activation='sigmoid')) + # Compile tf.keras Model. + optimizer = tf.train.GradientDescentOptimizer(0.2) model.compile(loss='binary_crossentropy', optimizer=optimizer) model.summary() tf.keras.backend.set_learning_phase(True) + + # Define a DistributionStrategy and convert the tf.keras Model to a + # tf.Estimator that utilizes the DistributionStrategy. 
+ strategy = tf.contrib.distribute.MirroredStrategy( + ['/device:GPU:0', '/device:GPU:1']) + config = tf.estimator.RunConfig(train_distribute=strategy) keras_estimator = tf.keras.estimator.model_to_estimator( - keras_model=model, config=config, model_dir=args[1]) + keras_model=model, config=config, model_dir=model_dir) + # Train and evaluate the tf.Estimator. keras_estimator.train(input_fn=input_fn, steps=10) eval_result = keras_estimator.evaluate(input_fn=input_fn) print('Eval result: {}'.format(eval_result)) + if __name__ == '__main__': - tf.app.run(argv=argv) + tf.app.run(argv=sys.argv) diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index e134fe34e1..d2054715f1 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -44,13 +44,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.distributions_and_v1_optimizers(), combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + combinations.combine(mode=["eager"], use_callable_loss=[True]), - combinations.combine(is_tpu=[False])) + - combinations.combine( - distribution=[combinations.tpu_strategy], - optimizer_fn=[combinations.adam_optimizer_v1_fn], - mode=["graph"], - use_callable_loss=[False], - is_tpu=[True])) + combinations.combine(is_tpu=[False])) + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=[ + combinations.adam_optimizer_v1_fn, + # TODO(isaprykin): Make Adam v2 work with while_loops + # and TPUs. + ], + mode=["graph"], + use_callable_loss=[False], + is_tpu=[True])) def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss, is_tpu): with distribution.scope(): @@ -101,7 +104,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution=[combinations.tpu_strategy], optimizer_fn=[ combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn, ], mode=["graph"], is_tpu=[True])) @@ -171,13 +175,28 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): set(created_variables)) @combinations.generate( - combinations.times(combinations.distributions_and_v1_optimizers(), - combinations.combine( - mode=["graph", "eager"], - momentum=[0.8, 0.9, 0.99], - renorm=[False, True]))) + combinations.times( + combinations.combine(momentum=[0.8, 0.9, 0.99], renorm=[False, True]), + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine( + mode=["graph", "eager"], + is_tpu=[False], + # TODO(isaprykin): Allow False here. Currently subsequent + # towers will re-execute UPDATE_OPS of previous towers. 
+ update_ops_in_cross_tower_mode=[True])) + + combinations.combine( + distribution=[combinations.tpu_strategy_single_iteration], + optimizer_fn=[ + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn + ], + mode=["graph"], + is_tpu=[True], + update_ops_in_cross_tower_mode=[False]))) def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum, - renorm): + renorm, is_tpu, + update_ops_in_cross_tower_mode): """Verifies that moving mean updates are reduced across towers.""" with distribution.scope(): num_towers = len(distribution.worker_devices) @@ -185,7 +204,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): optimizer_fn, batch_per_epoch=num_towers, momentum=momentum, - renorm=renorm) + renorm=renorm, + update_ops_in_tower_mode=not update_ops_in_cross_tower_mode) # Disable prefetching since that makes the specific input on each device # to be non deterministic, and this test relies on specific input being @@ -196,16 +216,18 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): dataset_fn).make_one_shot_iterator() def run_step(): - return control_flow_ops.group( - distribution.unwrap( - distribution.call_for_each_tower( - model_fn, - iterator.get_next(), - run_concurrently=batchnorm.built)) + - ops.get_collection(ops.GraphKeys.UPDATE_OPS)) + fetches = distribution.unwrap( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), + run_concurrently=batchnorm.built)) + if update_ops_in_cross_tower_mode: + fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) + return control_flow_ops.group(fetches) if not context.executing_eagerly(): with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) @@ -229,22 +251,40 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + @combinations.generate( combinations.times( combinations.combine( - distribution=[combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], - optimizer_fn=[combinations.gradient_descent_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v2_fn], - loss_reduction=[losses_impl.Reduction.SUM, - losses_impl.Reduction.MEAN, - losses_impl.Reduction.SUM_OVER_BATCH_SIZE, - losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS]), - combinations.combine(mode=["graph"], use_callable_loss=[True, False]) - + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + optimizer_fn=[ + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn + ], + loss_reduction=[ + losses_impl.Reduction.SUM, losses_impl.Reduction.MEAN, + losses_impl.Reduction.SUM_OVER_BATCH_SIZE, + losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS + ]), + combinations.times( + combinations.combine( + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + is_tpu=[False]), + combinations.combine( + mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True])) + + combinations.combine( + distribution=[combinations.tpu_strategy_single_iteration], + 
is_tpu=[True], + mode=["graph"], + use_callable_loss=[True, False]))) def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, - use_callable_loss): + use_callable_loss, is_tpu): with distribution.scope(): all_vars = [] @@ -280,12 +320,13 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): if not context.executing_eagerly(): with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) run_step() - self.assertEqual(distribution.num_towers, len(all_vars)) v = all_vars[0] self.assertTrue(all([v is vi for vi in all_vars[1:]])) weight = numpy.squeeze(self.evaluate(distribution.fetch(v))) @@ -312,6 +353,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 6c5c055070..3635bd2e34 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -370,22 +370,27 @@ class MirroredStrategyVariableCreationTest(test.TestCase): expected_sum = 0.0 expected_mean = 0.0 for i, d in enumerate(dist.worker_devices): - # Test access within a device scope, should see different values. - with ops.device(d): - v_sum_value = self.evaluate(ret_v_sum.read_value()) - v_mean_value = self.evaluate(ret_v_mean.read_value()) - expected = i + 3.0 - self.assertEqual(expected, v_sum_value) - expected_sum += expected - expected = i * 6.0 - self.assertEqual(expected, v_mean_value) - expected_mean += expected - - # fetch() should return the value you get by applying the - # reduction across all towers. - self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) + # Should see different values on different devices. + v_sum_value = self.evaluate(ret_v_sum.get(d).read_value()) + v_mean_value = self.evaluate(ret_v_mean.get(d).read_value()) + expected = i + 3.0 + self.assertEqual(expected, v_sum_value) + expected_sum += expected + expected = i * 6.0 + self.assertEqual(expected, v_mean_value) + expected_mean += expected expected_mean /= len(dist.worker_devices) + + # Without get(device), should return the value you get by + # applying the reduction across all towers (whether you use + # fetch(), get(), or nothing). + self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean))) + self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) + if not context.executing_eagerly(): + self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not # testing this in eager mode. 
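The arithmetic behind the updated assertions, spelled out for two worker devices (a plain-Python check, illustrative only):

  parts_sum = [i + 3.0 for i in range(2)]   # per-device values [3.0, 4.0]
  parts_mean = [i * 6.0 for i in range(2)]  # per-device values [0.0, 6.0]
  assert sum(parts_sum) == 7.0              # fetch(ret_v_sum): sum reduction
  assert sum(parts_mean) / 2 == 3.0         # fetch(ret_v_mean): mean reduction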
diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py index 0db0b59fca..d1fdb3279c 100644 --- a/tensorflow/contrib/distribute/python/single_loss_example.py +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -22,6 +22,7 @@ from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import step_fn from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.layers import normalization from tensorflow.python.ops import array_ops @@ -59,7 +60,7 @@ def minimize_loss_example(optimizer_fn, # TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be # fully defined for TPU. Remove this when XLA supports dynamic shapes. return dataset.apply( - batching.map_and_batch(lambda x: x, batch_size=2, drop_remainder=True)) + batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True)) # An Optimizer instance is created either outside or inside model_fn. outer_optimizer = None @@ -68,11 +69,10 @@ def minimize_loss_example(optimizer_fn, layer = core.Dense(1, use_bias=use_bias) - def model_fn(xs): + def model_fn(x): """A very simple model written by the user.""" def loss_fn(): - x = math_ops.reduce_mean(xs, keepdims=True) y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) return y * y @@ -89,7 +89,8 @@ def minimize_loss_example(optimizer_fn, def batchnorm_example(optimizer_fn, batch_per_epoch=1, momentum=0.9, - renorm=False): + renorm=False, + update_ops_in_tower_mode=False): """Example of non-distribution-aware legacy code with batch normalization.""" def dataset_fn(): @@ -103,12 +104,19 @@ def batchnorm_example(optimizer_fn, optimizer = optimizer_fn() batchnorm = normalization.BatchNormalization( renorm=renorm, momentum=momentum, fused=False) + layer = core.Dense(1, use_bias=False) def model_fn(x): + """A model that uses batchnorm.""" def loss_fn(): - y = math_ops.reduce_sum(batchnorm(x, training=True), axis=1) - loss = math_ops.reduce_mean(y - constant_op.constant(1.)) + y = batchnorm(x, training=True) + with ops.control_dependencies( + ops.get_collection(ops.GraphKeys.UPDATE_OPS) + if update_ops_in_tower_mode else []): + loss = math_ops.reduce_mean( + math_ops.reduce_sum(layer(y)) - constant_op.constant(1.)) + # `x` and `y` will be fetched by the gradient computation, but not `loss`. return loss # Callable loss. diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index a7e4fe80f3..75441786a6 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -33,7 +33,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.util import nest -# TODO(isaprykin): Consider whether inheriting is really appropriate. 
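The `_reduce` added at the end of this file's hunks computes a mean as a scaled cross-replica sum, using the identity mean(v_1, ..., v_N) = cross_replica_sum(v_i / N) when each replica contributes its own share; roughly:

  # Sketch only; tpu_ops.cross_replica_sum is the op this module already uses.
  scaled = value * (1.0 / num_cores_per_host)
  mean = tpu_ops.cross_replica_sum(scaled)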
class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" @@ -73,7 +72,6 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def infeed_input(i): """Get input, split it and then enqueue.""" iteration_inputs = [f.get(i) for f in feeds()] - infeed_inputs = [[inputs_per_core[core_id] for inputs_per_core in iteration_inputs] for core_id in range(self._num_cores_per_host)] @@ -117,3 +115,14 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): iterate_on_tpu, [], num_shards=self._num_cores_per_host) return control_flow_ops.group(tpu_result, enqueue_ops) + + def _reduce(self, method_string, value, destinations): + del destinations # TPU is graph mode only. Rely on implicit Send/Recv. + if method_string == 'mean': + # TODO(jhseu): Revisit once we support model-parallelism. + value *= (1. / self._num_cores_per_host) + return tpu_ops.cross_replica_sum(value) + + @property + def num_towers(self): + return self._num_cores_per_host diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index aaf177d07e..759f3c3599 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.training import checkpointable from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib @@ -60,7 +61,7 @@ class DistributedValues(object): else: device = distribute_lib.get_update_device() if device is None: - device = device_util.current() + return self._get_cross_tower() device = device_util.canonicalize(device) try: return self._index[device] @@ -231,12 +232,6 @@ class DistributedVariable(DistributedDelegate): self._primary_var.op.type) return self.get().op - def _as_graph_element(self): - # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): - return self._primary_var._as_graph_element() - return self.get()._as_graph_element() - def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" pass @@ -320,6 +315,18 @@ class MirroredVariable(DistributedVariable, Mirrored, def assign(self, *args, **kwargs): return self.get(device=_get_update_device()).assign(*args, **kwargs) + def _get_cross_tower(self): + device = device_util.canonicalize(device_util.current()) + if device in self._index: + return array_ops.identity(self._index[device]) + return array_ops.identity(self._primary_var) + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._primary_var._as_graph_element() + return self.get()._as_graph_element() + def _gather_saveables_for_checkpoint(self): """Overrides CheckpointableBase method. 
@@ -364,6 +371,12 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access +def _assert_tower_context(): + if not distribute_lib.get_tower_context(): + raise RuntimeError( + "Tower-local variables may only be assigned in a tower context.") + + class TowerLocalVariable(DistributedVariable, PerDevice, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are reduced on save.""" @@ -374,18 +387,35 @@ class TowerLocalVariable(DistributedVariable, PerDevice, super(TowerLocalVariable, self).__init__(index) def assign_sub(self, *args, **kwargs): + _assert_tower_context() return self.get().assign_sub(*args, **kwargs) def assign_add(self, *args, **kwargs): + _assert_tower_context() return self.get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): + _assert_tower_context() return self.get().assign(*args, **kwargs) @property def reduce_method(self): return self._reduce_method + def _get_cross_tower(self): + all_components = tuple(self._index.values()) + # TODO(josh11b): Use a strategy-specific method. + total = math_ops.add_n(all_components) + if self._reduce_method == "mean": + return total * (1./ len(all_components)) + return total + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._get_cross_tower() + return self.get()._as_graph_element() + def _gather_saveables_for_checkpoint(self): """Overrides CheckpointableBase method. @@ -672,11 +702,12 @@ class MultiWorkerDataset(object): return MultiWorkerDataIterator(iterators, self._worker_device_map) -class PerIteration(object): - """Holds input for multiple iterations at once.""" +class _PerKey(object): + """Holds data associated by keys.""" - def __init__(self, index): - self._index = index + def __init__(self, *index): + # pylint: disable=protected-access + self._index = list(index) def get(self, iteration): return array_ops.gather(self._index, iteration) @@ -687,6 +718,24 @@ class PerIteration(object): def get_dtype(self): return self._index[-1][-1].dtype + def __str__(self): + return "%s:%s" % (self.__class__.__name__, self._index) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._index) + + +class PerIteration(_PerKey): + """Holds input for multiple iterations at once.""" + + def __init__(self, *index): + # pylint: disable=protected-access + super(PerIteration, self).__init__(*[batch._index for batch in index]) + + +class Batches(_PerKey): + pass + class MultiIterator(object): """Iterator that returns results of multiple get_next()s.""" @@ -697,11 +746,31 @@ class MultiIterator(object): self._batches_per_iteration = batches_per_iteration def get_next(self, name=None): - return PerIteration([[ - self._dataset_iterator.get_next(name=name) - for _ in range(self._batches_per_iteration) - ] - for _ in range(self._iterations)]) + """Return PerIteration with `iterations x batches_per_iteration` inputs.""" + data = [] + for _ in range(self._batches_per_iteration): + batch = [] + for _ in range(self._iterations): + batch.append(self._dataset_iterator.get_next(name=name)) + data.append(batch) + + # Here is an example. Suppose each get_next returns a tuple of two tensors. 
+ # For 3 `iterations` and 2 `batches_per_iteration`, the `data` is: + # [[(a,z), (b,y), (c,x)], [(A,Z), (B,Y), (C,X)]] + # + # After the first `map_structure` it gets transformed to: + # [(Batches(a, A), Batches(z, Z)), + # (Batches(b, B), Batches(y, Y)), + # (Batches(c, C), Batches(x, X))] + # + # After the second `map_structure` it gets transformed to a tuple of: + # (PerIteration([Batches(a, A), Batches(b, B), Batches(c, C)]), + # PerIteration([Batches(z, Z), Batches(y, Y), Batches(x, X)])) + + data = nest.map_structure(Batches, *data) + data = nest.map_structure(PerIteration, *data) + + return data @property def initializer(self): diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index fad613155d..a1d56066b4 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -372,6 +372,7 @@ cuda_py_test( "//tensorflow/python:random_ops", "//tensorflow/python:variables", ], + shard_count = 4, ) cuda_py_test( @@ -459,7 +460,7 @@ cuda_py_test( cuda_py_test( name = "batch_reshape_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/batch_reshape_test.py"], additional_deps = [ ":distributions_py", @@ -578,7 +579,7 @@ cuda_py_test( cuda_py_test( name = "wishart_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/wishart_test.py"], additional_deps = [ ":distributions_py", @@ -866,7 +867,7 @@ cuda_py_test( cuda_py_test( name = "batch_normalization_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/bijectors/batch_normalization_test.py"], additional_deps = [ ":bijectors_py", diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index ca20442c39..dc45114b1c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test @@ -188,6 +189,15 @@ class ChainBijectorTest(test.TestCase): -np.log(6, dtype=np.float32) - np.sum(x), self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1))) + def testChainIldjWithPlaceholder(self): + chain = Chain((Exp(), Exp())) + samples = array_ops.placeholder( + dtype=np.float32, shape=[None, 10], name="samples") + ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0) + self.assertTrue(ildj is not None) + with self.test_session(): + ildj.eval({samples: np.zeros([2, 10], np.float32)}) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py index 7435bcbc68..b003526392 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py @@ -131,8 +131,8 @@ class 
MultivariateNormalFullCovarianceTest(test.TestCase): return mu, sigma def testKLBatch(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -156,6 +156,33 @@ class MultivariateNormalFullCovarianceTest(test.TestCase): self.assertAllClose(expected_kl_0, kl_v[0]) self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLBatchBroadcast(self): + batch_shape = [2] + event_shape = [3] + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + # No batch shape. + mu_b, sigma_b = self._random_mu_and_sigma([], event_shape) + mvn_a = ds.MultivariateNormalFullCovariance( + loc=mu_a, + covariance_matrix=sigma_a, + validate_args=True) + mvn_b = ds.MultivariateNormalFullCovariance( + loc=mu_b, + covariance_matrix=sigma_b, + validate_args=True) + + kl = ds.kl_divergence(mvn_a, mvn_b) + self.assertEqual(batch_shape, kl.get_shape()) + + kl_v = kl.eval() + expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :], + mu_b, sigma_b) + expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :], + mu_b, sigma_b) + self.assertAllClose(expected_kl_0, kl_v[0]) + self.assertAllClose(expected_kl_1, kl_v[1]) + def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b): """Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b).""" diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py index 685f32883d..b556d06123 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py @@ -235,8 +235,8 @@ class MultivariateNormalTriLTest(test.TestCase): return mu, sigma def testKLNonBatch(self): - batch_shape = () - event_shape = (2,) + batch_shape = [] + event_shape = [2] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -257,8 +257,8 @@ class MultivariateNormalTriLTest(test.TestCase): self.assertAllClose(expected_kl, kl_v) def testKLBatch(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -282,9 +282,36 @@ class MultivariateNormalTriLTest(test.TestCase): self.assertAllClose(expected_kl_0, kl_v[0]) self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLBatchBroadcast(self): + batch_shape = [2] + event_shape = [3] + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + # No batch shape. 
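For reference, the closed form these tests check against (via `_compute_non_batch_kl`) is the standard Gaussian KL divergence; a NumPy sketch, illustrative only:

  import numpy as np

  def kl_mvn(mu_a, sigma_a, mu_b, sigma_b):
    """KL(N(mu_a, sigma_a) || N(mu_b, sigma_b)) for full-rank covariances."""
    k = mu_a.shape[-1]
    sigma_b_inv = np.linalg.inv(sigma_b)
    diff = mu_b - mu_a
    return 0.5 * (np.trace(sigma_b_inv.dot(sigma_a))
                  + diff.dot(sigma_b_inv).dot(diff)
                  - k
                  + np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a)))

The broadcast test then continues below, comparing each batch member of mvn_a against the single, unbatched mvn_b.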
+      mu_b, sigma_b = self._random_mu_and_sigma([], event_shape)
+      mvn_a = ds.MultivariateNormalTriL(
+          loc=mu_a,
+          scale_tril=np.linalg.cholesky(sigma_a),
+          validate_args=True)
+      mvn_b = ds.MultivariateNormalTriL(
+          loc=mu_b,
+          scale_tril=np.linalg.cholesky(sigma_b),
+          validate_args=True)
+
+      kl = ds.kl_divergence(mvn_a, mvn_b)
+      self.assertEqual(batch_shape, kl.get_shape())
+
+      kl_v = kl.eval()
+      expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :],
+                                            mu_b, sigma_b)
+      expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :],
+                                            mu_b, sigma_b)
+      self.assertAllClose(expected_kl_0, kl_v[0])
+      self.assertAllClose(expected_kl_1, kl_v[1])
+
   def testKLTwoIdenticalDistributionsIsZero(self):
-    batch_shape = (2,)
-    event_shape = (3,)
+    batch_shape = [2]
+    event_shape = [3]
     with self.test_session():
       mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
       mvn_a = ds.MultivariateNormalTriL(
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
index 85ad23e413..b158a51bb0 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
@@ -20,10 +20,9 @@ from __future__ import print_function

 import itertools

-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.distributions import bijector


@@ -36,15 +35,6 @@ def _use_static_shape(input_tensor, ndims):
   return input_tensor.shape.is_fully_defined() and isinstance(ndims, int)


-def _maybe_get_event_ndims_statically(event_ndims):
-  static_event_ndims = (event_ndims if isinstance(event_ndims, int)
-                        else tensor_util.constant_value(event_ndims))
-  if static_event_ndims is not None:
-    return static_event_ndims
-
-  return event_ndims
-
-
 def _compute_min_event_ndims(bijector_list, compute_forward=True):
   """Computes the min_event_ndims associated with the given list of bijectors.
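For intuition about the `_inverse_log_det_jacobian` changes in the hunks below: the chain accumulates each bijector's ILDJ at the successively inverted point, while folding `event_ndims` through the intermediate event shapes. For the two-`Exp` chain exercised by the new test above, with scalar `y > 0` and `event_ndims=0`, this works out to (illustrative):

  # ildj of Exp at y is -log(y); inverting once gives log(y), so the second
  # Exp contributes -log(log(y)):
  #   ildj_chain(y) = -log(y) - log(log(y))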
@@ -238,13 +228,13 @@ class Chain(bijector.Bijector): return y def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant( - 0., dtype=y.dtype.base_dtype, name="inverse_log_det_jacobian") + y = ops.convert_to_tensor(y, name="y") + ildj = math_ops.cast(0., dtype=y.dtype.base_dtype) if not self.bijectors: return ildj - event_ndims = _maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_event_ndims_statically( self.inverse_min_event_ndims) if _use_static_shape(y, event_ndims): @@ -258,11 +248,12 @@ class Chain(bijector.Bijector): if _use_static_shape(y, event_ndims): event_shape = b.inverse_event_shape(event_shape) - event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_event_ndims_statically( + event_shape.ndims) else: event_shape = b.inverse_event_shape_tensor(event_shape) - event_ndims = _maybe_get_event_ndims_statically( - array_ops.rank(event_shape)) + event_ndims = self._maybe_get_event_ndims_statically( + array_ops.size(event_shape)) y = b.inverse(y, **kwargs.get(b.name, {})) return ildj @@ -274,13 +265,12 @@ class Chain(bijector.Bijector): def _forward_log_det_jacobian(self, x, **kwargs): x = ops.convert_to_tensor(x, name="x") - fldj = constant_op.constant( - 0., dtype=x.dtype, name="inverse_log_det_jacobian") + fldj = math_ops.cast(0., dtype=x.dtype.base_dtype) if not self.bijectors: return fldj - event_ndims = _maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_event_ndims_statically( self.forward_min_event_ndims) if _use_static_shape(x, event_ndims): @@ -293,13 +283,21 @@ class Chain(bijector.Bijector): x, event_ndims=event_ndims, **kwargs.get(b.name, {})) if _use_static_shape(x, event_ndims): event_shape = b.forward_event_shape(event_shape) - event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims) else: event_shape = b.forward_event_shape_tensor(event_shape) - event_ndims = _maybe_get_event_ndims_statically( - array_ops.rank(event_shape)) + event_ndims = self._maybe_get_event_ndims_statically( + array_ops.size(event_shape)) x = b.forward(x, **kwargs.get(b.name, {})) return fldj + def _maybe_get_event_ndims_statically(self, event_ndims): + event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically( + event_ndims) + if event_ndims_ is None: + return event_ndims + return event_ndims_ + + diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 8517a3bf7b..b8f352d5f5 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -36,9 +36,7 @@ def device_and_data_format(): 'channels_last') -def random_batch(batch_size, device_and_format=None): - _, data_format = device_and_format or device_and_data_format() - +def random_batch(batch_size, data_format): shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3) shape = (batch_size,) + shape @@ -70,7 +68,7 @@ class ResNet50Test(tf.test.TestCase): if defun: model.call = tfe.defun(model.call) with tf.device(device), tfe.execution_mode(execution_mode): - images, _ = random_batch(2) + images, _ = random_batch(2, data_format) output = model(images, training=False) tfe.async_wait() self.assertEqual((2, 1000), output.shape) @@ -91,7 +89,7 @@ class ResNet50Test(tf.test.TestCase): device, data_format = device_and_data_format() model = 
resnet50.ResNet50(data_format, include_top=False) with tf.device(device): - images, _ = random_batch(2) + images, _ = random_batch(2, data_format) output = model(images, training=False) output_shape = ((2, 2048, 1, 1) if data_format == 'channels_first' else (2, 1, 1, 2048)) @@ -101,7 +99,7 @@ class ResNet50Test(tf.test.TestCase): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') with tf.device(device): - images, _ = random_batch(2) + images, _ = random_batch(2, data_format) output = model(images, training=False) self.assertEqual((2, 2048), output.shape) @@ -115,7 +113,7 @@ class ResNet50Test(tf.test.TestCase): name='t0').as_default(), tf.contrib.summary.always_record_summaries(): with tf.device(device), tfe.execution_mode(execution_mode): optimizer = tf.train.GradientDescentOptimizer(0.1) - images, labels = random_batch(2) + images, labels = random_batch(2, data_format) train_one_step(model, images, labels, optimizer) self.assertEqual(320, len(model.variables)) tfe.async_wait() @@ -134,7 +132,7 @@ class ResNet50Test(tf.test.TestCase): model = resnet50.ResNet50(data_format) optimizer = tf.train.GradientDescentOptimizer(0.1) with tf.device(device): - images, labels = random_batch(2) + images, labels = random_batch(2, data_format) gc.disable() # Warm up. Note that this first run does create significant amounts of # garbage to be collected. The hope is that this is a build-only effect, @@ -202,18 +200,18 @@ class ResNet50Benchmarks(tf.test.Benchmark): # which forces a sync. This is a roundabout way, yes. tf.constant(1.).cpu() - def _benchmark_eager_apply(self, label, defun=False, execution_mode=None, - device_and_format=None): + def _benchmark_eager_apply(self, label, device_and_format, defun=False, + execution_mode=None, compiled=False): with tfe.execution_mode(execution_mode): - device, data_format = device_and_format or device_and_data_format() + device, data_format = device_and_format model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call) + model.call = tfe.defun(model.call, compiled=compiled) batch_size = 64 num_burn = 5 num_iters = 30 with tf.device(device): - images, _ = random_batch(batch_size, device_and_format) + images, _ = random_batch(batch_size, data_format) for _ in xrange(num_burn): model(images, training=False).cpu() if execution_mode: @@ -227,30 +225,34 @@ class ResNet50Benchmarks(tf.test.Benchmark): self._report(label, start, num_iters, device, batch_size, data_format) def benchmark_eager_apply_sync(self): - self._benchmark_eager_apply('eager_apply', defun=False) + self._benchmark_eager_apply('eager_apply', device_and_data_format(), + defun=False) def benchmark_eager_apply_async(self): self._benchmark_eager_apply( - 'eager_apply_async', defun=False, execution_mode=tfe.ASYNC) + 'eager_apply_async', device_and_data_format(), defun=False, + execution_mode=tfe.ASYNC) def benchmark_eager_apply_with_defun(self): - self._benchmark_eager_apply('eager_apply_with_defun', defun=True) + self._benchmark_eager_apply('eager_apply_with_defun', + device_and_data_format(), defun=True) def _benchmark_eager_train(self, label, make_iterator, + device_and_format, defun=False, execution_mode=None, - device_and_format=None): + compiled=False): with tfe.execution_mode(execution_mode): - device, data_format = device_and_format or device_and_data_format() + device, data_format = device_and_format for batch_size in self._train_batch_sizes(): - (images, labels) = random_batch(batch_size, 
device_and_format) + (images, labels) = random_batch(batch_size, data_format) num_burn = 3 num_iters = 10 model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call) + model.call = tfe.defun(model.call, compiled=compiled) optimizer = tf.train.GradientDescentOptimizer(0.1) with tf.device(device): @@ -273,18 +275,21 @@ class ResNet50Benchmarks(tf.test.Benchmark): self._report(label, start, num_iters, device, batch_size, data_format) def benchmark_eager_train_sync(self): - self._benchmark_eager_train('eager_train', MockIterator, defun=False) + self._benchmark_eager_train('eager_train', MockIterator, + device_and_data_format(), defun=False) def benchmark_eager_train_async(self): self._benchmark_eager_train( 'eager_train_async', MockIterator, + device_and_data_format(), defun=False, execution_mode=tfe.ASYNC) def benchmark_eager_train_with_defun(self): self._benchmark_eager_train( - 'eager_train_with_defun', MockIterator, defun=True) + 'eager_train_with_defun', MockIterator, + device_and_data_format(), defun=True) def benchmark_eager_train_datasets(self): @@ -294,7 +299,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): return tfe.Iterator(ds) self._benchmark_eager_train( - 'eager_train_dataset', make_iterator, defun=False) + 'eager_train_dataset', make_iterator, + device_and_data_format(), defun=False) def benchmark_eager_train_datasets_with_defun(self): @@ -304,7 +310,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): return tfe.Iterator(ds) self._benchmark_eager_train( - 'eager_train_dataset_with_defun', make_iterator, defun=True) + 'eager_train_dataset_with_defun', make_iterator, + device_and_data_format(), defun=True) if __name__ == '__main__': diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index e80ccbb74d..db50b33af2 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -57,7 +57,7 @@ class TFETest(test_util.TensorFlowTestCase): return math_ops.multiply(x, x) grad = tfe.gradients_function(square) - self.assertEquals([6], [x.numpy() for x in grad(3)]) + self.assertEquals([6], [x.numpy() for x in grad(3.)]) def testGradOfGrad(self): @@ -66,7 +66,7 @@ class TFETest(test_util.TensorFlowTestCase): grad = tfe.gradients_function(square) gradgrad = tfe.gradients_function(lambda x: grad(x)[0]) - self.assertEquals([2], [x.numpy() for x in gradgrad(3)]) + self.assertEquals([2], [x.numpy() for x in gradgrad(3.)]) def testCustomGrad(self): @@ -80,7 +80,7 @@ class TFETest(test_util.TensorFlowTestCase): return y, grad_fn grad = tfe.gradients_function(f) - self.assertEquals([12], [x.numpy() for x in grad(3)]) + self.assertEquals([12], [x.numpy() for x in grad(3.)]) def testGPU(self): if tfe.num_gpus() <= 0: diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 571e2e3a5d..e9a68801ef 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -17,6 +17,7 @@ py_library( ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":export", ":extenders", ":head", ":linear", @@ -181,6 +182,43 @@ py_test( ) py_library( + name = "export", + srcs = [ + "python/estimator/export.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "export_test", + size = "medium", + srcs = ["python/estimator/export_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/62863147 + deps = [ + ":export", + "//tensorflow/python:array_ops", + 
"//tensorflow/python:client_testlib", + "//tensorflow/python:metrics", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:session", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", + ], +) + +py_library( name = "head", srcs = [ "python/estimator/head.py", diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index d43b3ea6bf..ec502f86dd 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * +from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * from tensorflow.contrib.estimator.python.estimator.linear import * @@ -56,6 +57,8 @@ _allowed_symbols = [ 'TowerOptimizer', 'RNNClassifier', 'RNNEstimator', + 'export_saved_model_for_mode', + 'export_all_saved_models', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py new file mode 100644 index 0000000000..e7e366a3f2 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export.py @@ -0,0 +1,216 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrapper for methods to export train/eval graphs from Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import model_fn as model_fn_lib + + +def export_saved_model_for_mode( + estimator, export_dir_base, input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Exports a single train/eval/predict graph as a SavedModel. + + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. 
+
+ Sample usage:
+ ```python
+ classifier = tf.estimator.LinearClassifier(
+ feature_columns=[age, language])
+ classifier.train(input_fn=input_fn, steps=1000)
+
+ feature_spec = {
+ 'age': tf.placeholder(dtype=tf.int64),
+ 'language': tf.placeholder(dtype=tf.string)
+ }
+ label_spec = tf.placeholder(dtype=tf.int64)
+
+ train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn(
+ feature_spec, label_spec)
+
+ export_dir = tf.contrib.estimator.export_saved_model_for_mode(
+ classifier,
+ export_dir_base='my_model/',
+ input_receiver_fn=train_rcvr_fn,
+ mode=model_fn_lib.ModeKeys.TRAIN)
+
+ # export_dir is a timestamped directory with the SavedModel, which
+ # can be used for serving, analysis with TFMA, or directly loaded in.
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+ ...
+ ```
+
+ This method takes an input_receiver_fn and mode. For the mode passed in,
+ this method builds a new graph by calling the input_receiver_fn to obtain
+ feature and label `Tensor`s. Next, this method calls the `Estimator`'s
+ model_fn in the passed mode to generate the model graph based on
+ those features and labels, and restores the given checkpoint
+ (or, lacking that, the most recent checkpoint) into the graph.
+ Finally, it creates a timestamped export directory below the
+ export_dir_base, and writes a `SavedModel` into it containing
+ the `MetaGraphDef` for the given mode and its associated signatures.
+
+ For prediction, the exported `MetaGraphDef` will provide one `SignatureDef`
+ for each element of the export_outputs dict returned from the model_fn,
+ named using the same keys. One of these keys is always
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
+ signature will be served when a serving request does not specify one.
+ For each signature, the outputs are provided by the corresponding
+ `ExportOutput`s, and the inputs are always the input receivers provided by
+ the input_receiver_fn.
+
+ For training and evaluation, the train_op is stored in an extra collection,
+ and loss, metrics, and predictions are included in a SignatureDef for the
+ mode in question.
+
+ Extra assets may be written into the SavedModel via the assets_extra
+ argument. This should be a dict, where each key gives a destination path
+ (including the filename) relative to the assets.extra directory. The
+ corresponding value gives the full path of the source file to be copied.
+ For example, the simple case of copying a single file without renaming it
+ is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
+
+ Args:
+ estimator: an instance of tf.estimator.Estimator
+ export_dir_base: A string containing a directory in which to create
+ timestamped subdirectories containing exported SavedModels.
+ input_receiver_fn: a function that takes no arguments and
+ returns the appropriate subclass of `InputReceiver`.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel, or `None` if no extra assets are needed.
+ as_text: whether to write the SavedModel proto in text format.
+ checkpoint_path: The checkpoint path to export. If `None` (the default),
+ the most recent checkpoint found within the model directory is chosen.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ mode: tf.estimator.ModeKeys value indicating which mode will be exported.
+
+ Returns:
+ The string path to the exported directory.
+
+ Raises:
+ ValueError: if input_receiver_fn is None, no export_outputs
+ are provided, or no checkpoint can be found.
+ """
+ # pylint: enable=line-too-long
+
+ # pylint: disable=protected-access
+ return estimator._export_saved_model_for_mode(
+ export_dir_base, input_receiver_fn,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ checkpoint_path=checkpoint_path,
+ strip_default_attrs=strip_default_attrs,
+ mode=mode)
+ # pylint: enable=protected-access
+
+
+def export_all_saved_models(
+ estimator, export_dir_base, input_receiver_fn_map,
+ assets_extra=None,
+ as_text=False,
+ checkpoint_path=None,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
+ """Exports requested train/eval/predict graphs as separate SavedModels.
+
+ This is a wrapper around export_saved_model_for_mode that accepts
+ multiple modes simultaneously and creates directories for each under
+ export_dir_base. See `Estimator.export_saved_model_for_mode` for
+ further details on how the export works for each mode.
+
+ Sample usage:
+ ```python
+ classifier = tf.estimator.LinearClassifier(
+ feature_columns=[age, language])
+ classifier.train(input_fn=input_fn)
+
+ feature_spec = {
+ 'age': tf.placeholder(dtype=tf.int64),
+ 'language': tf.placeholder(dtype=tf.string)
+ }
+ label_spec = tf.placeholder(dtype=tf.int64)
+
+ train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn(
+ feature_spec, label_spec)
+
+ serve_rcvr_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+
+ rcvr_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: train_rcvr_fn,
+ model_fn_lib.ModeKeys.PREDICT: serve_rcvr_fn,
+ }
+
+ export_dirs = tf.contrib.estimator.export_all_saved_models(
+ classifier,
+ export_dir_base='my_model/',
+ input_receiver_fn_map=rcvr_fn_map)
+
+ # export_dirs is a dict of directories with SavedModels, which
+ # can be used for serving, analysis with TFMA, or directly loaded in.
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.TRAINING],
+ export_dirs[tf.estimator.ModeKeys.TRAIN])
+ ...
+ ```
+
+ Args:
+ estimator: an instance of tf.estimator.Estimator
+ export_dir_base: A string containing a directory in which to create
+ timestamped subdirectories containing exported SavedModels.
+ input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn
+ mappings, where the input_receiver_fn is a function that takes no
+ arguments and returns the appropriate subclass of `InputReceiver`.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel, or `None` if no extra assets are needed.
+ as_text: whether to write the SavedModel proto in text format.
+ checkpoint_path: The checkpoint path to export. If `None` (the default),
+ the most recent checkpoint found within the model directory is chosen.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ + Returns: + A dict of tf.estimator.ModeKeys value to string path for each exported + directory. + + Raises: + ValueError: if any input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_all_saved_models( + export_dir_base, input_receiver_fn_map, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/export_test.py b/tensorflow/contrib/estimator/python/estimator/export_test.py new file mode 100644 index 0000000000..89d02582e1 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export_test.py @@ -0,0 +1,391 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib wrapping of export_saved_model_for_mode functionality. + +These are direct copies of the tests included in core, with import locations +changed. These should be removed when the functionality in core is part of the +public API. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from tensorflow.contrib.estimator.python.estimator import export as contrib_export +from tensorflow.python.client import session +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import training +from tensorflow.python.util import compat + + +def _model_fn_for_export_tests(features, labels, mode): + _, _ = features, labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + update_global_step = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([update_global_step]): + train_op = constant_op.constant(2.) 
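+ # Gating the constant train_op on the assign_add above means every train
+ # step still advances global_step, even though no variables are updated.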
+ return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=train_op, + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + + +def _x_y_input_fn(): + return ({'x': constant_op.constant([[1], [1]]), + 'y': constant_op.constant([[2], [2]])}, + constant_op.constant([[1], [1]])) + + +def _model_fn_with_x_y(features, labels, mode): + _ = labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(36., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + else: + prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else '' + + multiplied = math_ops.multiply( + features['x'], features['y'], name='{}multiplied'.format(prefix)) + metrics = {'mean': metrics_lib.mean(features['x'] - features['y'], + name='{}mean'.format(prefix))} + variables.Variable(1., name='later_var') + variables.Variable(3., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=multiplied, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + +def _get_serving_input_receiver_fn(): + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + return export.build_parsing_serving_input_receiver_fn(feature_spec) + + +def _get_supervised_input_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + 'y': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_y') + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='truth') + + return export.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + +class EstimatorExportTest(test.TestCase): + + def test_export_saved_model_train(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN) + + def test_export_saved_model_eval(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL) + + def test_export_saved_model_predict(self): + self._test_export_saved_model_for_mode( + _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT) + + def _test_export_saved_model_for_mode(self, input_receiver_fn, mode): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = contrib_export.export_saved_model_for_mode( + est, export_dir_base, input_receiver_fn, mode=mode) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + tag_set = model_fn_lib.EXPORT_TAG_MAP[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('name_collision_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_receiver_map(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertFalse('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_train_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertTrue('mean/update_op' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_eval_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertTrue('eval_mean/value' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_no_serving(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 2) + # Restore, to validate that the export was well-formed. 
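+ # export_all_saved_models writes each requested mode to its own
+ # timestamped directory, so the TRAIN and EVAL exports are checked
+ # independently below.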
+ export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + # TODO(karmel): is this the desired behavior when names are shared? + self.assertTrue('feature_x_1' in graph_ops) + self.assertTrue('feature_y_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_three_defs(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items(): + export_dir = export_dirs[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('global_step/Assign' in graph_ops) + self.assertTrue('global_step/Initializer/zeros' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_all_vars(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_name_collision(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertEqual(3, collection_vars[-1].eval()) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # This is a non-obvious detail: when we load the estimator spec + # for predict, name_collision gets set to 36. However, we then restore + # from checkpoint, which should overwrite that var and make it the 3 + # from training. In practice, this would not be a good way to write + # a model_fn, but leaving this check in for now to ensure consistency + # with what would happen given our current order of spec, then + # checkpoint. + self.assertEqual(3, collection_vars[-1].eval()) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def _test_export_all_saved_models(self, input_receiver_fn_map): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_x_y) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dirs = contrib_export.export_all_saved_models( + est, export_dir_base, input_receiver_fn_map) + + # Check that all the files are in the right places. 
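+ # Only export_dir_base is known up front; the per-mode subdirectories are
+ # timestamped, so they are validated through the returned paths.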
+ self.assertTrue(gfile.Exists(export_dir_base))
+
+ for _, export_dir in export_dirs.items():
+ self._validate_exported_files(export_dir)
+
+ return export_dirs, tmpdir
+
+ def _validate_exported_files(self, export_dir):
+ self.assertTrue(gfile.Exists(export_dir))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('saved_model.pb'))))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables'))))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.index'))))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.data-00000-of-00001'))))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index 5d19bf4714..109fdd3883 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -560,10 +560,10 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
 weights=weights,
 processed_labels=processed_labels)
- def create_estimator_spec(
+ def _create_tpu_estimator_spec(
 self, features, mode, logits, labels=None, optimizer=None,
 train_op_fn=None, regularization_losses=None):
- """Returns an `EstimatorSpec`.
+ """Returns a `model_fn._TPUEstimatorSpec`.
 Args:
 features: Input `dict` of `Tensor` or `SparseTensor` objects.
@@ -586,7 +586,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
 `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
 avoid scaling errors.
 Returns:
- `EstimatorSpec`.
+ `model_fn._TPUEstimatorSpec`.
 Raises:
 ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
 mode, or if both are set.
@@ -606,7 +606,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
 classifier_output = head_lib._classification_output( # pylint:disable=protected-access
 scores=probabilities, n_classes=self._n_classes,
 label_vocabulary=self._label_vocabulary)
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access
 mode=model_fn.ModeKeys.PREDICT,
 predictions=predictions,
 export_outputs={
@@ -629,16 +629,18 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
 # Eval.
 if mode == model_fn.ModeKeys.EVAL:
- return model_fn.EstimatorSpec(
+ return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access
 mode=model_fn.ModeKeys.EVAL,
 predictions=predictions,
 loss=regularized_training_loss,
- eval_metric_ops=self._eval_metric_ops(
- labels=processed_labels,
- probabilities=probabilities,
- weights=weights,
- unreduced_loss=unreduced_loss,
- regularization_loss=regularization_loss))
+ eval_metrics=head_lib._create_eval_metrics_tuple( # pylint:disable=protected-access
+ self._eval_metric_ops, {
+ 'labels': processed_labels,
+ 'probabilities': probabilities,
+ 'weights': weights,
+ 'unreduced_loss': unreduced_loss,
+ 'regularization_loss': regularization_loss,
+ }))
 # Train.
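 # Exactly one of `optimizer` and `train_op_fn` may be set, per the Raises
 # clause in the docstring above.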
if optimizer is not None: @@ -672,7 +674,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access summary.scalar( head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION), # pylint:disable=protected-access regularization_loss) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index d5b3b279a1..7355a403ae 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -381,7 +381,7 @@ py_test( py_test( name = "rev_block_lib_test", - size = "small", + size = "medium", srcs = ["python/layers/rev_block_lib_test.py"], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 3b053cd4c6..4a360711f8 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -485,6 +485,7 @@ py_test( name = "state_saving_rnn_estimator_test", size = "medium", srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["noasan"], deps = [ diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 3744abd860..dfc6a393d0 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -468,10 +468,15 @@ class Experiment(object): on which that evaluation was based. At the beginning of evaluation, the passed `eval_results` will be None so it's expected that the predicate function handles that gracefully. - When `predicate_fn` is not specified, continuous eval will run in an - infinite loop (if `train_steps` is None). or exit once global step - reaches `train_steps`. - + Continuous eval behavior under different conditions: + * When `predicate_fn` is specified: + + if `train_steps` is None, run until `predicate_fn` returns False. + + if `train_steps` is specified, run until either global step + reaches `train_steps` or `predicate_fn` returns False. + * When `predicate_fn` is not specified: + + if `train_steps` is None, run in an infinite loop. + + if `train_steps` is specified, run until global step reaches + `train_steps`. export: Whether to export from this step. Default is 'True'. Raises: diff --git a/tensorflow/contrib/lite/RELEASE.md b/tensorflow/contrib/lite/RELEASE.md new file mode 100644 index 0000000000..8fd63d5cee --- /dev/null +++ b/tensorflow/contrib/lite/RELEASE.md @@ -0,0 +1,8 @@ +# Release 0.1.7 + +* TensorFlow Lite 0.1.7 is based on tag `tflite-v0.1.7` (git commit + fa1db5eb0da85b5baccc2a46d534fdeb3bb473d0). +* To reproduce the iOS library, it's required to cherry pick git commit + f1f1d5172fe5bfeaeb2cf657ffc43ba744187bee to fix a dependency issue. +* The code is based on TensorFlow 1.8.0 release candidate and it's very close + to TensorFlow 1.8.0 release. diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 4910c89eae..35cf43dd32 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -162,6 +162,9 @@ typedef struct { } TfLitePadParams; typedef struct { +} TfLitePadV2Params; + +typedef struct { // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. // For now we will fix the maximum possible number of dimensions. 
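 // The fixed-size array below therefore caps shapes at 8 dimensions.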
int shape[8]; diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 962a7a8970..a038acf284 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -85,6 +85,11 @@ typedef enum { kTfLiteBuiltinMinimum = 57, kTfLiteBuiltinLess = 58, kTfLiteBuiltinNeg = 59, + kTfLiteBuiltinPadv2 = 60, + kTfLiteBuiltinGreater = 61, + kTfLiteBuiltinGreaterEqual = 62, + kTfLiteBuiltinLessEqual = 63, + kTfLiteBuiltinSelect = 64, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 0051ee84ec..f45fcceb2e 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -281,6 +281,32 @@ Options { } ``` +**GREATER** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + greater than the corresponding element of the second tensor. +} +``` + +**GREATER_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + greater than or equal to the corresponding element of the second tensor. +} +``` + **L2_NORMALIZATION** ``` @@ -325,6 +351,19 @@ Outputs { } ``` +**LESS_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is less + than or equal to the corresponding element of the second tensor. +} +``` + **LOCAL_RESPONSE_NORMALIZATION** ``` @@ -600,6 +639,20 @@ Outputs { } ``` +**SELECT** + +``` +Inputs { + 0: tensor + 1: tensor + 2: tensor +} +Outputs { + 0: tensor that contains the elementwise values of 'tensor 1' if the + corresponding value of 'tensor 0' is true or the value of 'tensor 2' if false. +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 1074f64263..0450e86ae7 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -201,7 +201,7 @@ class Interpreter { // Overrides execution plan. This bounds checks indices sent in. TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan); - // Get a tensor data structure. + // Get a mutable tensor data structure. // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this // read/write access to structure TfLiteTensor* tensor(int tensor_index) { @@ -210,9 +210,14 @@ class Interpreter { return &context_.tensors[tensor_index]; } + // Get an immutable tensor data structure. + const TfLiteTensor* tensor(int tensor_index) const { + if (tensor_index >= context_.tensors_size || tensor_index < 0) + return nullptr; + return &context_.tensors[tensor_index]; + } + // Get a pointer to an operation and registration data structure if in bounds. - // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this - // read/write access to structure const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration( int node_index) const { if (node_index >= nodes_and_registration_.size() || node_index < 0) @@ -220,7 +225,8 @@ class Interpreter { return &nodes_and_registration_[node_index]; } - // Perform a checked cast to the appropriate tensor type. + // Perform a checked cast to the appropriate tensor type (mutable pointer + // version). 
template <class T> T* typed_tensor(int tensor_index) { if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { @@ -231,6 +237,18 @@ class Interpreter { return nullptr; } + // Perform a checked cast to the appropriate tensor type (immutable pointer + // version). + template <class T> + const T* typed_tensor(int tensor_index) const { + if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) { + if (tensor_ptr->type == typeToTfLiteType<T>()) { + return reinterpret_cast<const T*>(tensor_ptr->data.raw); + } + } + return nullptr; + } + // Return a pointer into the data of a given input tensor. The given index // must be between 0 and inputs().size(). template <class T> @@ -238,13 +256,20 @@ class Interpreter { return typed_tensor<T>(inputs_[index]); } - // Return a pointer into the data of a given output tensor. The given index - // must be between 0 and outputs().size(). + // Return a mutable pointer into the data of a given output tensor. The given + // index must be between 0 and outputs().size(). template <class T> T* typed_output_tensor(int index) { return typed_tensor<T>(outputs_[index]); } + // Return an immutable pointer into the data of a given output tensor. The + // given index must be between 0 and outputs().size(). + template <class T> + const T* typed_output_tensor(int index) const { + return typed_tensor<T>(outputs_[index]); + } + // Change the dimensionality of a given tensor. Note, this is only acceptable // for tensor indices that are inputs. // Returns status of failure or success. diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 1dda55b8ed..1e57922603 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -46,12 +46,27 @@ android_library( ], ) -java_library( +android_library( name = "ovicbenchmarkerlib", srcs = [ "ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java", "ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", ], + manifest = "AndroidManifest.xml", + visibility = ["//visibility:public"], + deps = [ + ":tensorflowlite", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) + +java_library( + name = "ovicbenchmarkerlib_java", + srcs = [ + "ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java", + "ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + ], javacopts = JAVACOPTS, visibility = ["//visibility:public"], deps = [ @@ -170,18 +185,14 @@ java_test( size = "medium", srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], data = [ - "ovic/src/testdata/float_model.lite", - "ovic/src/testdata/labels.txt", - "ovic/src/testdata/low_res_model.lite", - "ovic/src/testdata/quantized_model.lite", - "ovic/src/testdata/test_image_128.jpg", - "ovic/src/testdata/test_image_224.jpg", + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", ], javacopts = JAVACOPTS, test_class = "org.tensorflow.ovic.OvicClassifierTest", visibility = ["//visibility:public"], deps = [ - ":ovicbenchmarkerlib", + ":ovicbenchmarkerlib_java", "@com_google_truth", "@junit", ], diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml index 20f520814d..ef8a9e0845 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml +++ 
b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml @@ -13,51 +13,55 @@ See the License for the specific language governing permissions and limitations under the License. --> -<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android" - xmlns:app="http://schemas.android.com/apk/res-auto" + +<LinearLayout + xmlns:android="http://schemas.android.com/apk/res/android" android:layout_width="match_parent" - android:layout_height="match_parent"> + android:layout_height="match_parent" + android:background="#bb7700" + android:orientation="horizontal"> + + <com.example.android.tflitecamerademo.AutoFitTextureView + android:id="@+id/texture" + android:layout_width="0dp" + android:layout_height="match_parent" + android:layout_weight=".8"/> + + <LinearLayout + android:layout_width="0dp" + android:layout_height="match_parent" + android:layout_weight=".2" + android:orientation="vertical"> + + <ImageView + android:id="@+id/logoview" + android:layout_width="wrap_content" + android:layout_height="wrap_content" + android:scaleType="centerInside" + android:src="@drawable/logo"/> - <LinearLayout + <ToggleButton + android:id="@+id/button" + android:layout_width="match_parent" + android:layout_height="wrap_content" + android:textOff="@string/tflite" + android:textOn="@string/nnapi"/> + <NumberPicker + android:id="@+id/np" + android:layout_width="wrap_content" + android:layout_height="47dp" + android:layout_gravity="center_horizontal" + android:visibility="visible"/> + + <TextView + android:id="@+id/text" + android:textStyle="bold" android:layout_width="match_parent" android:layout_height="match_parent" - android:background="#bb7700" - android:orientation="horizontal" - android:weightSum="100"> - - <LinearLayout - android:layout_width="match_parent" - android:layout_height="match_parent" - android:layout_weight="30" - android:orientation="vertical"> - - <com.example.android.tflitecamerademo.AutoFitTextureView - android:id="@+id/texture" - android:layout_width="match_parent" - android:layout_height="match_parent" - android:layout_weight="100" /> - - <ImageView - android:id="@+id/logoview" - android:layout_width="match_parent" - android:layout_height="wrap_content" - android:layout_weight="100" - android:scaleType="centerCrop" - android:src="@drawable/logo" /> - - </LinearLayout> - - <TextView - android:id="@+id/text" - android:layout_width="match_parent" - android:layout_height="match_parent" - android:layout_weight="70" - android:paddingLeft="5dp" - android:paddingTop="20dp" - android:textColor="#FFF" - android:textSize="20sp" - android:textStyle="bold" /> - - </LinearLayout> - -</RelativeLayout> + android:paddingTop="20dp" + android:textColor="#FFF" + android:textSize="20sp"/> + + </LinearLayout> +</LinearLayout> + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml index d12435d5ab..72a229ecdb 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml @@ -15,45 +15,47 @@ --> <RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android" xmlns:app="http://schemas.android.com/apk/res-auto" - xmlns:tools="http://schemas.android.com/tools" android:layout_width="match_parent" - android:layout_height="match_parent"> + android:layout_height="match_parent" + android:background="#bb7700"> - 
<LinearLayout + <com.example.android.tflitecamerademo.AutoFitTextureView + android:id="@+id/texture" android:layout_width="match_parent" android:layout_height="match_parent" - android:orientation="vertical" - android:weightSum="60"> - - <FrameLayout - android:id="@+id/control" - android:layout_width="match_parent" - android:layout_height="match_parent" - android:layout_alignParentBottom="true" - android:layout_alignParentStart="true" - android:layout_weight="60" - android:background="#cc7700" - android:paddingLeft="20dp" - android:paddingStart="20dp"> - - </FrameLayout> + android:layout_weight="1" /> - <com.example.android.tflitecamerademo.AutoFitTextureView - android:id="@+id/texture" + <LinearLayout android:layout_width="wrap_content" android:layout_height="wrap_content" + android:layout_alignParentBottom="true" + android:layout_alignParentEnd="false" android:layout_alignParentStart="true" - android:layout_alignParentLeft="true" - android:layout_alignParentTop="true" /> + android:layout_alignParentTop="false" + android:background="#bb7700" + android:orientation="vertical" + android:weightSum="100"> + + <ImageView + android:id="@+id/logoview2" + android:layout_width="wrap_content" + android:layout_height="wrap_content" + android:layout_weight="30" + android:scaleType="fitStart" + android:src="@drawable/logo" /> <TextView android:id="@+id/text" android:layout_width="match_parent" - android:layout_height="match_parent" - android:layout_weight="20" + android:layout_height="wrap_content" + android:layout_alignParentBottom="true" + android:layout_alignParentEnd="true" + android:layout_alignParentRight="true" + android:layout_weight="30" android:textColor="#FFF" android:textSize="20sp" android:textStyle="bold" /> + </LinearLayout> <RelativeLayout @@ -83,33 +85,4 @@ android:layout_below="@+id/button" android:visibility="visible" /> </RelativeLayout> - - <RelativeLayout - android:id="@+id/control2" - android:layout_width="match_parent" - android:layout_height="135dp" - android:layout_alignParentLeft="true" - android:layout_alignParentStart="true" - android:layout_alignTop="@+id/control" - android:layout_marginLeft="300dp" - android:layout_marginStart="300dp" - android:background="@color/control_background"> - - <ToggleButton - android:id="@+id/button" - android:textOff="@string/tflite" - android:textOn="@string/nnapi" - android:layout_width="wrap_content" - android:layout_height="wrap_content" - android:layout_alignParentLeft="true" - android:layout_alignParentStart="true" /> - - <NumberPicker - android:id="@+id/np" - android:layout_width="wrap_content" - android:layout_height="wrap_content" - android:layout_below="@+id/button" - android:visibility="visible" /> - </RelativeLayout> - </RelativeLayout> diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md index 76c33838bf..77799b3569 100644 --- a/tensorflow/contrib/lite/java/ovic/README.md +++ b/tensorflow/contrib/lite/java/ovic/README.md @@ -6,7 +6,7 @@ This folder contains building code for track one of the [Low Power ImageNet Reco Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK. -## To test the benchmarker: +## Test the benchmarker: The testing utilities helps the developers (you) to make sure that your submissions in TfLite format will be processed as expected in the competition's benchmarking system. 
@@ -37,7 +37,7 @@ unzip -j /tmp/ovic.zip -d tensorflow/contrib/lite/java/ovic/src/testdata/
 You can run the test with Bazel as below. This helps to ensure that the installation is correct.
 ```sh
-bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java:OvicClassifierTest --test_output=all
+bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java:OvicClassifierTest --cxxopt=-Wno-all --test_output=all
 ```
 ### Test your submissions
@@ -56,28 +56,83 @@ cp /tmp/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/
 The test images can be found at `tensorflow/contrib/lite/java/ovic/src/testdata/test_image_*.jpg`. You may reuse these images if your image resolutions are 128x128 or 224x224.
-* Add your model and test image to the BUILD rule:
+* Add your model and test image to the BUILD rule at `tensorflow/contrib/lite/java/ovic/src/testdata/BUILD`:
 ```JSON
-java_test(
- name = "OvicClassifierTest",
- size = "medium",
- srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"],
- data = [
- "ovic/src/testdata/float_model.lite",
- "ovic/src/testdata/labels.txt",
- "ovic/src/testdata/low_res_model.lite",
- "ovic/src/testdata/quantized_model.lite",
- "ovic/src/testdata/test_image_128.jpg",
- "ovic/src/testdata/test_image_224.jpg",
- "ovic/src/testdata/my_model.lite", # <--- Your submission.
- "ovic/src/testdata/my_test_image.jpg", # <--- Your test image.
- ],
- ...
+filegroup(
+ name = "ovic_testdata",
+ srcs = [
+ "@tflite_ovic_testdata//:float_model.lite",
+ "@tflite_ovic_testdata//:low_res_model.lite",
+ "@tflite_ovic_testdata//:quantized_model.lite",
+ "@tflite_ovic_testdata//:test_image_128.jpg",
+ "@tflite_ovic_testdata//:test_image_224.jpg",
+ "my_model.lite", # <--- Your submission.
+ "my_test_image.jpg", # <--- Your test image.
+ ],
+ ...
 ```
 * Modify `OvicClassifierTest.java` to test your model.
-Change `TEST_IMAGE_PATH` to `testdata/my_test_image.jpg`. If your model runs inference in floating point, change `FLOAT_MODEL_PATH` to `testdata/my_model.lite`. If your model runs [quantized inference](https://www.tensorflow.org/performance/quantization), change `QUANTIZED_MODEL_PATH` to `testdata/my_model.lite`.
+Change `TEST_IMAGE_PATH` to `my_test_image.jpg`. Change either `FLOAT_MODEL_PATH` or `QUANTIZED_MODEL_PATH` to `my_model.lite`, depending on whether your model runs inference in float or [8-bit](https://www.tensorflow.org/performance/quantization).
 Now you can run the bazel tests to catch any runtime issues with the submission.
+
+Note: Please make sure that your submission passes the test. If a submission fails to pass the test, it will not be processed by the submission server.
+
+## Measure on-device latency
+
+We provide two ways to measure the on-device latency of your submission. The first is through our competition server, which is reliable and repeatable but limited to a few trials per day. The second is through the benchmarker APK, which requires a device and may not be as accurate as the server, but has a fast turn-around and no access limitations. We recommend that participants use the benchmarker APK for early development, and reserve the competition server for evaluating promising submissions.
+
+### Running the benchmarker app
+
+Make sure that you have followed the instructions in [Test your submissions](#test-your-submissions) to add your model to the testdata folder and to the corresponding build rules.
+
+Modify `tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java`:
+
+* Add your model to the benchmarker APK by changing `MODEL_PATH` and `TEST_IMAGE_PATH` below to your submission and test image.
+
+```
+ private static final String TEST_IMAGE_PATH = "my_test_image.jpg";
+ private static final String MODEL_PATH = "my_model.lite";
+```
+
+* Adjust the benchmark parameters when needed:
+
+You can change the length of each experiment and the processor affinity below. `BIG_CORE_MASK` is an integer whose binary encoding represents the set of used cores. This number is phone-specific. For example, Pixel 2 has 8 cores: the 4 little cores are represented by the 4 least significant bits, and the 4 big cores by the 4 most significant bits. Therefore a mask value of 16, or in binary `00010000`, represents using only the first big core. The mask 32, or in binary `00100000`, uses the second big core and should deliver results identical to mask 16, because the big cores are interchangeable.
+
+```
+ /** Wall time for each benchmarking experiment. */
+ private static final double WALL_TIME = 3000;
+ /** Maximum number of iterations in each benchmarking experiment. */
+ private static final int MAX_ITERATIONS = 100;
+ /** Mask for binding to a single big core. Pixel 1 (4), Pixel 2 (16). */
+ private static final int BIG_CORE_MASK = 16;
+```
+
+Note: You'll need ROOT access to the phone to change processor affinity.
+
+* Build and install the app.
+
+```
+bazel build -c opt --cxxopt=--std=c++11 --cxxopt=-Wno-all //tensorflow/contrib/lite/java/ovic/demo/app:ovic_benchmarker_binary
+adb install -r bazel-bin/tensorflow/contrib/lite/java/ovic/demo/app/ovic_benchmarker_binary.apk
+```
+
+Start the app and click the dark-green `Start` button. The button should turn bright green, signaling that the experiment is running. The benchmarking results will be displayed after about the `WALL_TIME` you specified above. For example:
+
+```
+my_model.lite: Average latency=158.6ms after 20 runs.
+```
+
+### Sample latencies
+
+Note: The benchmarking results can be quite different depending on the background processes running on the phone. A few things that help stabilize the app's readings are placing the phone on a cooling plate, restarting the phone, and shutting down internet access.
+
+| Model | Pixel 1 latency (ms) | Pixel 2 latency (ms) |
+| -------------------- |:---------------------:| --------------------:|
+| float_model.lite | 120 | 155 |
+| quantized_model.lite | 85 | 74 |
+| low_res_model.lite | 4.2 | 4.0 |
+
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml b/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml
new file mode 100644
index 0000000000..55f2961fd7
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml
@@ -0,0 +1,48 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!--
+ Copyright 2018 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+--> + +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="ovic.demo.app" + android:versionCode="1" + android:versionName="1.0" > + + <uses-sdk + android:minSdkVersion="19" + android:targetSdkVersion="21" /> + + <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" /> + <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" /> + <uses-permission android:name="android.permission.READ_PHONE_STATE" /> + + <application + android:allowBackup="true" + android:icon="@drawable/ic_launcher" + android:largeHeap="true" + android:label="@string/app_name"> + <activity + android:name="ovic.demo.app.OvicBenchmarkerActivity" + android:label="@string/app_name" + android:screenOrientation="portrait"> + + <intent-filter> + <action android:name="android.intent.action.MAIN" /> + <category android:name="android.intent.category.LAUNCHER" /> + </intent-filter> + </activity> + </application> + +</manifest> diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD new file mode 100644 index 0000000000..47101ff574 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD @@ -0,0 +1,29 @@ +# Sample app for OVIC benchmarking. +licenses(["notice"]) # Apache 2.0 + +android_binary( + name = "ovic_benchmarker_binary", + srcs = [ + "OvicBenchmarker.java", + "OvicBenchmarkerActivity.java", + ], + assets = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + ], + assets_dir = "", + custom_package = "ovic.demo.app", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".lite", + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = ["manual"], + deps = [ + "//tensorflow/contrib/lite/java:ovicbenchmarkerlib", + "//tensorflow/contrib/lite/java:tensorflowlite", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java index d0102883e6..113ab74a20 100644 --- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java +++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java @@ -1,4 +1,4 @@ -/*Copyright 2018 Google LLC +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -package org.tensorflow.ovic; +package ovic.demo.app; import android.graphics.Bitmap; import android.os.SystemClock; @@ -22,6 +22,8 @@ import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; +import org.tensorflow.ovic.OvicClassifier; +import org.tensorflow.ovic.OvicSingleImageResult; /** * Class that benchmarks image classifier models. 
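The `OvicBenchmarkerActivity` added below drives this `OvicBenchmarker` through a simple setup-then-iterate loop. Stripped of the Android UI plumbing, the control flow amounts to roughly the following sketch (a sketch under assumptions, not code from this commit: method names are taken from the activity below, the `java.io`/`java.nio`/`android.graphics`/`org.tensorflow.ovic` imports are assumed, and `labelStream`, `modelBuffer`, and `testBitmap` stand in for the assets the activity loads):

```java
// Sketch of the benchmarking loop that OvicBenchmarkerActivity implements.
// Assumes labelStream/modelBuffer come from the app assets and testBitmap
// is the decoded test image.
static double averageLatencyMs(
    InputStream labelStream, MappedByteBuffer modelBuffer, Bitmap testBitmap)
    throws IOException, InterruptedException {
  OvicBenchmarker benchmarker = new OvicBenchmarker(3000.0); // WALL_TIME in ms
  double totalLatencyMs = 0.0;
  int iterations = 0;
  while (iterations < 100 && !benchmarker.shouldStop()) { // MAX_ITERATIONS cap
    if (!benchmarker.readyToTest()) {
      benchmarker.getReadyToTest(labelStream, modelBuffer); // one-time setup
    }
    OvicSingleImageResult result = benchmarker.doTestIteration(testBitmap);
    if (result == null) {
      break; // inference produced no result
    }
    totalLatencyMs += result.latency;
    iterations++;
  }
  return iterations > 0 ? totalLatencyMs / iterations : 0.0;
}
```

The activity below wraps this core loop with UI updates, processor-affinity pinning, and error reporting.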
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
new file mode 100644
index 0000000000..59457c308a
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
@@ -0,0 +1,247 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package ovic.demo.app;
+
+import android.app.Activity;
+import android.content.res.AssetFileDescriptor;
+import android.content.res.AssetManager;
+import android.graphics.Bitmap;
+import android.graphics.BitmapFactory;
+import android.os.Bundle;
+import android.os.Process;
+import android.os.SystemClock;
+import android.util.Log;
+import android.view.View;
+import android.widget.TextView;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.text.DecimalFormat;
+import org.tensorflow.ovic.OvicSingleImageResult;
+
+/** Class that benchmarks image classifier models. */
+public class OvicBenchmarkerActivity extends Activity {
+  /** Tag for the {@link Log}. */
+  private static final String TAG = "OvicBenchmarkerActivity";
+
+  /** Name of the label file stored in Assets. */
+  private static final String LABEL_PATH = "labels.txt";
+
+  private static final String TEST_IMAGE_PATH = "test_image_224.jpg";
+  private static final String MODEL_PATH = "float_model.lite";
+  /**
+   * Each button press will launch a benchmarking experiment. The experiment stops when either the
+   * total native latency reaches WALL_TIME or the number of iterations reaches MAX_ITERATIONS,
+   * whichever comes first.
+   */
+  /** Wall time for each benchmarking experiment. */
+  private static final double WALL_TIME = 3000;
+  /** Maximum number of iterations in each benchmarking experiment. */
+  private static final int MAX_ITERATIONS = 100;
+  /** Mask for binding to a single big core. Pixel 1 (4), Pixel 2 (16). */
+  private static final int BIG_CORE_MASK = 16;
+  /** Amount of time in milliseconds to wait for affinity to set. */
+  private static final int WAIT_TIME_FOR_AFFINITY = 1000;
+
+  /* The model to be benchmarked. */
+  private MappedByteBuffer model = null;
+  private InputStream labelInputStream = null;
+  private OvicBenchmarker benchmarker;
+  /** Inference result of each iteration. */
+  OvicSingleImageResult iterResult = null;
+
+  private TextView textView = null;
+  // private Button startButton = null;
+  private static final DecimalFormat df2 = new DecimalFormat(".##");
+
+  @Override
+  protected void onCreate(Bundle savedInstanceState) {
+    super.onCreate(savedInstanceState);
+    setContentView(R.layout.activity_main);
+
+    // TextView used to display the progress, for information purposes only.
+    textView = (TextView) findViewById(R.id.textView);
+  }
+
+  private Bitmap loadTestBitmap() throws IOException {
+    InputStream imageStream = getAssets().open(TEST_IMAGE_PATH);
+    return BitmapFactory.decodeStream(imageStream);
+  }
+
+  public void initializeTest() throws IOException {
+    Log.i(TAG, "Initializing benchmarker.");
+    benchmarker = new OvicBenchmarker(WALL_TIME);
+    AssetManager am = getAssets();
+    AssetFileDescriptor fileDescriptor = am.openFd(MODEL_PATH);
+    FileInputStream modelInputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
+    FileChannel fileChannel = modelInputStream.getChannel();
+    long startOffset = fileDescriptor.getStartOffset();
+    long declaredLength = fileDescriptor.getDeclaredLength();
+    model = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+    labelInputStream = am.open(LABEL_PATH);
+  }
+
+  public Boolean doTestIteration() throws IOException, InterruptedException {
+    if (benchmarker == null) {
+      throw new RuntimeException("Benchmarker has not been initialized.");
+    }
+    if (benchmarker.shouldStop()) {
+      return false;
+    }
+    if (!benchmarker.readyToTest()) {
+      Log.i(TAG, "Getting ready to test.");
+      benchmarker.getReadyToTest(labelInputStream, model);
+      if (!benchmarker.readyToTest()) {
+        throw new RuntimeException("Failed to get the benchmarker ready.");
+      }
+    }
+    Log.i(TAG, "Going to do test iter.");
+    // Start testing.
+    Bitmap testImageBitmap = loadTestBitmap();
+    iterResult = benchmarker.doTestIteration(testImageBitmap);
+    testImageBitmap.recycle();
+    if (iterResult == null) {
+      throw new RuntimeException("Inference failed to produce a result.");
+    }
+    Log.i(TAG, iterResult.toString());
+    return true;
+  }
+
+  public void startPressed(View view) throws IOException {
+    Log.i(TAG, "Start pressed");
+    try {
+      initializeTest();
+    } catch (IOException e) {
+      Log.e(TAG, "Can't initialize benchmarker.", e);
+      throw e;
+    }
+    String displayText = "";
+    try {
+      setProcessorAffinity(BIG_CORE_MASK);
+    } catch (IOException e) {
+      Log.e(TAG, e.getMessage());
+      displayText = e.getMessage() + "\n";
+    }
+    Log.i(TAG, "Successfully initialized benchmarker.");
+    int testIter = 0;
+    Boolean iterSuccess = false;
+    double totalLatency = 0.0;
+    while (testIter < MAX_ITERATIONS) {
+      try {
+        iterSuccess = doTestIteration();
+      } catch (IOException e) {
+        Log.e(TAG, "Error during iteration " + testIter);
+        throw e;
+      } catch (InterruptedException e) {
+        Log.e(TAG, "Interrupted at iteration " + testIter);
+      }
+      if (!iterSuccess) {
+        break;
+      }
+      testIter++;
+      totalLatency += (double) iterResult.latency;
+    }
+    Log.i(TAG, "Benchmarking finished");
+
+    if (textView != null) {
+      if (testIter > 0) {
+        textView.setText(
+            displayText
+                + MODEL_PATH
+                + ": Average latency="
+                + df2.format(totalLatency / testIter)
+                + "ms after "
+                + testIter
+                + " runs.");
+      } else {
+        textView.setText("Benchmarker failed to complete a single iteration.");
+      }
+    }
+  }
+
+  private static void setProcessorAffinity(int mask) throws IOException {
+    int myPid = Process.myPid();
+    Log.i(TAG, String.format("Setting processor affinity to 0x%02x", mask));
+
+    String command = String.format("taskset -a -p %x %d", mask, myPid);
+    try {
+      Runtime.getRuntime().exec(command).waitFor();
+    } catch (InterruptedException e) {
+      throw new IOException("Interrupted: " + e);
+    }
+
+    // Make sure the setting took effect; try for up to a second to confirm the change. If the
+    // affinity cannot be confirmed within that time, fail.
+    long startTimeMs = SystemClock.elapsedRealtime();
+    while (true) {
+      int readBackMask = readCpusAllowedMask();
+      if (readBackMask == mask) {
+        Log.i(TAG, String.format("Successfully set affinity to 0x%02x", mask));
+        return;
+      }
+      if (SystemClock.elapsedRealtime() > startTimeMs + WAIT_TIME_FOR_AFFINITY) {
+        throw new IOException(
+            String.format(
+                "Core-binding failed: affinity set to 0x%02x but read back as 0x%02x\n"
+                    + "Please root the device.",
+                mask, readBackMask));
+      }
+
+      try {
+        Thread.sleep(50);
+      } catch (InterruptedException e) {
+        // Ignore an interrupted sleep; we will sleep again, and the mask comparison above is
+        // the final cross-check.
+      }
+    }
+  }
+
+  public static int readCpusAllowedMask() throws IOException {
+    // Read the Cpus_allowed mask from /proc/self/status.
+    final String pathname = "/proc/self/status";
+    final String resultPrefix = "Cpus_allowed:";
+    File file = new File(pathname);
+    String line = "<NO LINE READ>";
+    String allowedCPU = "";
+    Integer allowedMask = null;
+    BufferedReader bufReader = null;
+    try {
+      bufReader = new BufferedReader(new FileReader(file));
+      while ((line = bufReader.readLine()) != null) {
+        if (line.startsWith(resultPrefix)) {
+          allowedMask = Integer.valueOf(line.substring(resultPrefix.length()).trim(), 16);
+          allowedCPU = bufReader.readLine();
+          break;
+        }
+      }
+    } catch (RuntimeException e) {
+      throw new IOException(
+          "Invalid number in " + pathname + " line: \"" + line + "\": " + e.getMessage());
+    } finally {
+      if (bufReader != null) {
+        bufReader.close();
+      }
+    }
+    if (allowedMask == null) {
+      throw new IOException(pathname + " missing " + resultPrefix + " line");
+    }
+    Log.i(TAG, allowedCPU);
+    return allowedMask;
+  }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
new file mode 100644
index 0000000000..c5d19bad89
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
@@ -0,0 +1,58 @@
+apply plugin: 'com.android.application'
+
+android {
+    compileSdkVersion 26
+    buildToolsVersion "26.0.1"
+    defaultConfig {
+        applicationId "android.example.com.ovicbenchmarker"
+        minSdkVersion 15
+        targetSdkVersion 26
+        versionCode 1
+        versionName "1.0"
+        testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
+
+        // Remove this block.
+ jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "lite", "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'com.android.support:appcompat-v7:25.2.0' + compile 'com.android.support.constraint:constraint-layout:1.0.2' + compile 'com.android.support:design:25.2.0' + compile 'com.android.support:support-annotations:25.3.1' + compile 'com.android.support:support-v13:25.2.0' + + compile 'org.tensorflow:tensorflow-lite:+' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png Binary files differnew file mode 100644 index 0000000000..715d1b6d69 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png Binary files differnew file mode 100644 index 0000000000..9beff0885f --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml new file mode 100644 index 0000000000..93f5c6a016 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml @@ -0,0 +1,39 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- + Copyright 2018 The Android Open Source Project + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+--> +<selector + xmlns:android="http://schemas.android.com/apk/res/android"> + <item + android:state_enabled="false"> + <shape android:shape="rectangle"> + <solid android:color="#808080"/> + </shape> + </item> + <item + android:state_enabled="true" + android:state_pressed="true"> + <shape android:shape="rectangle"> + <solid android:color="#44ff44"/> + </shape> + </item> + <item + android:state_enabled="true" + android:state_pressed="false"> + <shape android:shape="rectangle" > + <solid android:color="#227f22"/> + </shape> + </item> +</selector> diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml new file mode 100644 index 0000000000..e9d83bae54 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml @@ -0,0 +1,54 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- + Copyright 2018 The Android Open Source Project + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--> +<RelativeLayout + xmlns:android="http://schemas.android.com/apk/res/android" + xmlns:tools="http://schemas.android.com/tools" + android:layout_width="match_parent" + android:layout_height="match_parent" + android:paddingBottom="@dimen/activity_vertical_margin" + android:paddingLeft="@dimen/activity_horizontal_margin" + android:paddingRight="@dimen/activity_horizontal_margin" + android:paddingTop="@dimen/activity_vertical_margin" + tools:context="ovic.demo.app.OvicBenchmarkerActivity"> + + <TextView + android:layout_width="wrap_content" + android:layout_height="wrap_content" + android:text="@string/initial_status_msg" + android:id="@+id/textView" + android:layout_above="@+id/button_start" + android:layout_alignParentTop="true"/> + + <Button + android:layout_width="wrap_content" + android:layout_height="wrap_content" + android:text="@string/start_label" + android:id="@id/button_start" + android:layout_alignParentBottom="true" + android:layout_alignParentLeft="true" + android:background="@drawable/start_button_color" + android:padding="10dp" + android:layout_marginRight="30dp" + android:layout_marginLeft="100dp" + android:layout_marginTop="10dp" + android:foreground="#000000" + android:textColor="#ffffff" + android:enabled="true" + style="?android:attr/buttonBarButtonStyle" + android:onClick="startPressed"/> + +</RelativeLayout> diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/values/dimens.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/values/dimens.xml new file mode 100644 index 0000000000..250b581430 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/values/dimens.xml @@ -0,0 +1,20 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- + Copyright 2018 The Android Open Source Project + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--> +<resources> + <dimen name="activity_vertical_margin">20dp</dimen> + <dimen name="activity_horizontal_margin">16dp</dimen> +</resources> diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml new file mode 100644 index 0000000000..d26beb1d27 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml @@ -0,0 +1,22 @@ +<?xml version="1.0" encoding="utf-8"?> +<!-- + Copyright 2018 The Android Open Source Project + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--> +<resources> + <string name="app_name" translatable="false">Benchmarker</string> + + <string name="start_label" translatable="false">Start</string> + <string name="initial_status_msg" translatable="false"> Press start to run the benchmarks.</string> +</resources> diff --git a/tensorflow/contrib/lite/java/ovic/demo/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/build.gradle new file mode 100644 index 0000000000..b78a0b86c9 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/build.gradle @@ -0,0 +1,23 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. + +buildscript { + repositories { + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:2.3.1' + + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + jcenter() + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradle.properties b/tensorflow/contrib/lite/java/ovic/demo/gradle.properties new file mode 100644 index 0000000000..aac7c9b461 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/gradle.properties @@ -0,0 +1,17 @@ +# Project-wide Gradle settings. + +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. + +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html + +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx1536m + +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. 
More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.jar b/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.jar Binary files differnew file mode 100644 index 0000000000..13372aef5e --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.jar diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.properties b/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000..fa7a38a0e4 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Thu Sep 28 09:01:41 PDT 2017 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-3.3-all.zip diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradlew b/tensorflow/contrib/lite/java/ovic/demo/gradlew new file mode 100755 index 0000000000..9d82f78915 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/gradlew @@ -0,0 +1,160 @@ +#!/usr/bin/env bash + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS="" + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn ( ) { + echo "$*" +} + +die ( ) { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; +esac + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? 
-eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules +function splitJvmOpts() { + JVM_OPTS=("$@") +} +eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS +JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" + +exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" diff --git a/tensorflow/contrib/lite/java/ovic/demo/gradlew.bat b/tensorflow/contrib/lite/java/ovic/demo/gradlew.bat new file mode 100644 index 0000000000..8a0b282aa6 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/gradlew.bat @@ -0,0 +1,90 @@ +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS= + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. 
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto init
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:init
+@rem Get command-line arguments, handling Windows variants
+
+if not "%OS%" == "Windows_NT" goto win9xME_args
+if "%@eval[2+2]" == "4" goto 4NT_args
+
+:win9xME_args
+@rem Slurp the command line arguments.
+set CMD_LINE_ARGS=
+set _SKIP=2
+
+:win9xME_args_slurp
+if "x%~1" == "x" goto execute
+
+set CMD_LINE_ARGS=%*
+goto execute
+
+:4NT_args
+@rem Get arguments from the 4NT Shell from JP Software
+set CMD_LINE_ARGS=%$
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
+
+:end
+@rem End local scope for the variables with windows NT shell
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega
diff --git a/tensorflow/contrib/lite/java/ovic/demo/settings.gradle b/tensorflow/contrib/lite/java/ovic/demo/settings.gradle
new file mode 100644
index 0000000000..e7b4def49c
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/demo/settings.gradle
@@ -0,0 +1 @@
+include ':app'
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
index b2dfd8f2e7..4cf51bb0fa 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
@@ -67,7 +67,7 @@ public class OvicClassifier {
       });
 
   /** Initializes an {@code OvicClassifier}. */
-  OvicClassifier(InputStream labelInputStream, MappedByteBuffer model)
+  public OvicClassifier(InputStream labelInputStream, MappedByteBuffer model)
       throws IOException, RuntimeException {
     if (model == null) {
       throw new RuntimeException("Input model is empty.");
@@ -106,7 +106,7 @@ public class OvicClassifier {
 
   /** Classifies a {@link ByteBuffer} image. */
   // @throws RuntimeException if model is uninitialized.
- OvicSingleImageResult classifyByteBuffer(ByteBuffer imgData) throws RuntimeException { + public OvicSingleImageResult classifyByteBuffer(ByteBuffer imgData) { if (tflite == null) { throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed."); } diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java index 098ed8ceba..56f3e7604a 100644 --- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java +++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java @@ -45,17 +45,17 @@ public final class OvicClassifierTest { private ByteBuffer lowResTestImage = null; private OvicSingleImageResult testResult = null; private static final String LABELS_PATH = - "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt"; + "tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt"; private static final String QUANTIZED_MODEL_PATH = - "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/quantized_model.lite"; + "external/tflite_ovic_testdata/quantized_model.lite"; private static final String LOW_RES_MODEL_PATH = - "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/low_res_model.lite"; + "external/tflite_ovic_testdata/low_res_model.lite"; private static final String FLOAT_MODEL_PATH = - "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/float_model.lite"; + "external/tflite_ovic_testdata/float_model.lite"; private static final String TEST_IMAGE_PATH = - "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/test_image_224.jpg"; + "external/tflite_ovic_testdata/test_image_224.jpg"; private static final String TEST_LOW_RES_IMAGE_PATH = - "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/test_image_128.jpg"; + "external/tflite_ovic_testdata/test_image_128.jpg"; private static final int TEST_IMAGE_GROUNDTRUTH = 653; // "military uniform" @Before diff --git a/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD b/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD new file mode 100644 index 0000000000..1021ea30dd --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD @@ -0,0 +1,19 @@ +# Testdata for OVIC benchmarker demo App and tests. 
+licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "ovic_testdata", + srcs = [ + "@tflite_ovic_testdata//:float_model.lite", + "@tflite_ovic_testdata//:low_res_model.lite", + "@tflite_ovic_testdata//:quantized_model.lite", + "@tflite_ovic_testdata//:test_image_128.jpg", + "@tflite_ovic_testdata//:test_image_224.jpg", + ], + visibility = ["//visibility:public"], +) + +exports_files( + ["labels.txt"], + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index feab18b5c2..79e3c9f266 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -164,6 +164,7 @@ cc_library( "register.cc", "reshape.cc", "resize_bilinear.cc", + "select.cc", "skip_gram.cc", "space_to_batch_nd.cc", "space_to_depth.cc", @@ -870,6 +871,23 @@ tf_cc_test( ], ) +tf_cc_test( + name = "select_test", + size = "small", + srcs = [ + "select_test.cc", + ], + tags = [ + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index 87c413cb98..2885ce032b 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -28,7 +28,7 @@ constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; constexpr int kOutputTensor = 0; -TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -56,61 +56,139 @@ TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } -TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { +#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \ + requires_broadcast \ + ? reference_ops::Broadcast##opname( \ + GetTensorData<type>(input1), GetTensorDims(input1), \ + GetTensorData<type>(input2), GetTensorDims(input2), \ + GetTensorData<bool>(output), GetTensorDims(output)) \ + : reference_ops::opname( \ + GetTensorData<type>(input1), GetTensorDims(input1), \ + GetTensorData<type>(input2), GetTensorDims(input2), \ + GetTensorData<bool>(output), GetTensorDims(output)); + +TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. 
+ switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, Greater, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, Greater, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type other than float|int"); + return kTfLiteError; + } + return kTfLiteOk; +} +TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, GreaterEqual, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type other than float|int"); + return kTfLiteError; + } + return kTfLiteOk; +} -#define TF_LITE_LESS(type, opname) \ - reference_ops::opname(GetTensorData<type>(input1), GetTensorDims(input1), \ - GetTensorData<type>(input2), GetTensorDims(input2), \ - GetTensorData<bool>(output), GetTensorDims(output)); +TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, Less, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, Less, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, Less, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type other than float|int"); + return kTfLiteError; + } + return kTfLiteOk; +} +TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); // TODO(renjieliu): Support quantized data. 
- if (requires_broadcast) { - switch (input1->type) { - case kTfLiteFloat32: - TF_LITE_LESS(float, BroadcastLess); - break; - case kTfLiteInt32: - TF_LITE_LESS(int32_t, BroadcastLess); - break; - case kTfLiteInt64: - TF_LITE_LESS(int64_t, BroadcastLess); - break; - default: - context->ReportError(context, - "Does not support type other than float|int"); - return kTfLiteError; - } - } else { - switch (input1->type) { - case kTfLiteFloat32: - TF_LITE_LESS(float, Less); - break; - case kTfLiteInt32: - TF_LITE_LESS(int32_t, Less); - break; - case kTfLiteInt64: - TF_LITE_LESS(int64_t, Less); - break; - default: - context->ReportError(context, - "Does not support type other than float|int"); - return kTfLiteError; - } + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, LessEqual, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, LessEqual, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type other than float|int"); + return kTfLiteError; } -#undef TF_LITE_LESS return kTfLiteOk; } } // namespace comparisons +TfLiteRegistration* Register_GREATER() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::GreaterEval}; + return &r; +} + +TfLiteRegistration* Register_GREATER_EQUAL() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::GreaterEqualEval}; + return &r; +} + TfLiteRegistration* Register_LESS() { - static TfLiteRegistration r = {nullptr, nullptr, comparisons::LessPrepare, - comparisons::LessEval}; + static TfLiteRegistration r = { + nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::LessEval}; + return &r; +} + +TfLiteRegistration* Register_LESS_EQUAL() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::LessEqualEval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc index da2d7f8589..835d238d36 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc @@ -23,6 +23,139 @@ namespace { using ::testing::ElementsAreArray; +class GreaterOpModel : public SingleOpModel { + public: + GreaterOpModel(std::initializer_list<int> input1_shape, + std::initializer_list<int> input2_shape, + TensorType input_type) { + input1_ = AddInput(input_type); + input2_ = AddInput(input_type); + output_ = AddOutput(TensorType_BOOL); + SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions, + CreateGreaterOptions(builder_).Union()); + BuildInterpreter({input1_shape, input2_shape}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; +}; + +TEST(ComparisonsTest, GreaterFloat) { + GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, GreaterInt) { 
+ GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, GreaterBroadcast) { + GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, GreaterBroadcastTwoD) { + GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, + false, true, false, true})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); +} + +class GreaterEqualOpModel : public SingleOpModel { + public: + GreaterEqualOpModel(std::initializer_list<int> input1_shape, + std::initializer_list<int> input2_shape, + TensorType input_type) { + input1_ = AddInput(input_type); + input2_ = AddInput(input_type); + output_ = AddOutput(TensorType_BOOL); + SetBuiltinOp(BuiltinOperator_GREATER_EQUAL, + BuiltinOptions_GreaterEqualOptions, + CreateGreaterEqualOptions(builder_).Union()); + BuildInterpreter({input1_shape, input2_shape}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; +}; + +TEST(ComparisonsTest, GreaterEqualFloat) { + GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, GreaterEqualInt) { + GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, GreaterEqualBroadcast) { + GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) { + GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + 
EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, + false, true, true, true})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); +} + class LessOpModel : public SingleOpModel { public: LessOpModel(std::initializer_list<int> input1_shape, @@ -47,7 +180,7 @@ class LessOpModel : public SingleOpModel { int output_; }; -TEST(ArgMaxOpTest, LessFloat) { +TEST(ComparisonsTest, LessFloat) { LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); @@ -57,7 +190,7 @@ TEST(ArgMaxOpTest, LessFloat) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } -TEST(ArgMaxOpTest, LessInt) { +TEST(ComparisonsTest, LessInt) { LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor<int>(model.input2(), {1, 2, 6, 5}); @@ -67,7 +200,7 @@ TEST(ArgMaxOpTest, LessInt) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } -TEST(ArgMaxOpTest, LessBroadcast) { +TEST(ComparisonsTest, LessBroadcast) { LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor<int>(model.input2(), {7}); @@ -77,7 +210,7 @@ TEST(ArgMaxOpTest, LessBroadcast) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } -TEST(ArgMaxOpTest, LessBroadcastTwoD) { +TEST(ComparisonsTest, LessBroadcastTwoD) { LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8}); model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); @@ -88,6 +221,72 @@ TEST(ArgMaxOpTest, LessBroadcastTwoD) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); } +class LessEqualOpModel : public SingleOpModel { + public: + LessEqualOpModel(std::initializer_list<int> input1_shape, + std::initializer_list<int> input2_shape, + TensorType input_type) { + input1_ = AddInput(input_type); + input2_ = AddInput(input_type); + output_ = AddOutput(TensorType_BOOL); + SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions, + CreateLessEqualOptions(builder_).Union()); + BuildInterpreter({input1_shape, input2_shape}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; +}; + +TEST(ComparisonsTest, LessEqualFloat) { + LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, LessEqualInt) { + LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, LessEqualBroadcast) { + LessEqualOpModel 
model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor<int>(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(ComparisonsTest, LessEqualBroadcastTwoD) { + LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true, + true, false, true, false})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index df29172f83..d8340d426a 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -5,6 +5,7 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") tflite_deps_intel = [ "@arm_neon_2_x86_sse", @@ -157,6 +158,7 @@ cc_library( ":quantization_util", ":strided_slice_logic", ":types", + ":reference_base", ":round", "//third_party/eigen3", "@gemmlowp", @@ -386,6 +388,9 @@ cc_library( ":armv7a": [ ":neon_tensor_utils", ], + ":haswell": [ + ":neon_tensor_utils", + ], ":ios_armv7": [ ":neon_tensor_utils", ], @@ -424,6 +429,7 @@ cc_test( "//conditions:default": [], }), linkstatic = 1, + tags = ["tflite_not_portable_ios"], deps = [ ":tensor_utils", "//tensorflow/contrib/lite:builtin_op_data", @@ -458,3 +464,5 @@ cc_test( ) exports_files(["optimized/eigen_tensor_reduced_instantiations_oss.h"]) + +tflite_portable_test_suite() diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index 18601df22c..ede95dfee0 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -113,6 +113,20 @@ inline int32 MultiplyByQuantizedMultiplier(int32 x, int32 quantized_multiplier, right_shift); } +template <typename T> +int CountLeadingZeros(T integer_input) { + static_assert(std::is_unsigned<T>::value, + "Only unsigned integer types handled."); + const T one_in_leading_positive = static_cast<T>(1) + << (std::numeric_limits<T>::digits - 1); + int leading_zeros = 0; + while (integer_input < one_in_leading_positive) { + integer_input <<= 1; + ++leading_zeros; + } + return leading_zeros; +} + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index fd14cb23ea..580d208beb 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -31,6 +31,7 @@ limitations under the License. 
#include "public/gemmlowp.h" #include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/round.h" #include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/contrib/lite/kernels/internal/types.h" @@ -38,6 +39,16 @@ limitations under the License. namespace tflite { namespace optimized_ops { +// Unoptimized reference ops: +using reference_ops::BroadcastGreater; +using reference_ops::BroadcastGreaterEqual; +using reference_ops::BroadcastLess; +using reference_ops::BroadcastLessEqual; +using reference_ops::Greater; +using reference_ops::GreaterEqual; +using reference_ops::Less; +using reference_ops::LessEqual; + // Make a local VectorMap typedef allowing to map a float array // as a Eigen vector expression. The std::conditional here is to // construct the suitable Eigen type for the constness of the @@ -5851,10 +5862,26 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, } template <typename T> -inline void Pad(const T* input_data, const Dims<4>& input_dims, - const std::vector<int>& left_paddings, - const std::vector<int>& right_paddings, T* output_data, - const Dims<4>& output_dims, const int32_t pad_value) { +void TypedMemset(void* ptr, T value, size_t num) { + // Optimization for common cases where memset() will suffice. + if (value == 0 || std::is_same<T, uint8_t>::value) { + memset(ptr, value, num * sizeof(T)); + } else { + // Default implementation for cases where memset() will not preserve the + // bytes, e.g., typically when sizeof(T) > sizeof(uint8_t). + char* pos = static_cast<char*>(ptr); + for (size_t i = 0; i < num; ++i) { + memcpy(pos, &value, sizeof(T)); + pos = pos + sizeof(T); + } + } +} + +template <typename T> +inline void PadV2(const T* input_data, const Dims<4>& input_dims, + const std::vector<int>& left_paddings, + const std::vector<int>& right_paddings, T* output_data, + const Dims<4>& output_dims, const T pad_value) { gemmlowp::ScopedProfilingLabel label("Pad"); TFLITE_DCHECK_EQ(left_paddings.size(), 4); TFLITE_DCHECK_EQ(right_paddings.size(), 4); @@ -5877,27 +5904,28 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, const int input_depth = ArraySize(input_dims, 0); if (left_b_padding != 0) { - memset(output_data, pad_value, - left_b_padding * output_height * output_width * output_depth * - sizeof(T)); + TypedMemset<T>( + output_data, pad_value, + left_b_padding * output_height * output_width * output_depth); } for (int out_b = left_b_padding; out_b < output_batch - right_b_padding; ++out_b) { if (left_h_padding != 0) { - memset(output_data + Offset(output_dims, 0, 0, 0, out_b), pad_value, - left_h_padding * output_width * output_depth * sizeof(T)); + TypedMemset<T>(output_data + Offset(output_dims, 0, 0, 0, out_b), + pad_value, left_h_padding * output_width * output_depth); } for (int out_h = left_h_padding; out_h < output_height - right_h_padding; ++out_h) { if (left_w_padding != 0) { - memset(output_data + Offset(output_dims, 0, 0, out_h, out_b), pad_value, - left_w_padding * output_depth * sizeof(T)); + TypedMemset<T>(output_data + Offset(output_dims, 0, 0, out_h, out_b), + pad_value, left_w_padding * output_depth); } for (int out_w = left_w_padding; out_w < output_width - right_w_padding; ++out_w) { if (left_d_padding != 0) { - memset(output_data + Offset(output_dims, 0, out_w, out_h, 
out_b), - pad_value, left_d_padding * sizeof(T)); + TypedMemset<T>( + output_data + Offset(output_dims, 0, out_w, out_h, out_b), + pad_value, left_d_padding); } T* out = output_data + @@ -5908,35 +5936,46 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, memcpy(out, in, input_depth * sizeof(T)); if (right_d_padding != 0) { - memset( + TypedMemset<T>( output_data + Offset(output_dims, output_depth - right_d_padding, out_w, out_h, out_b), - pad_value, right_d_padding * sizeof(T)); + pad_value, right_d_padding); } } if (right_w_padding != 0) { - memset( + TypedMemset<T>( output_data + Offset(output_dims, 0, output_width - right_w_padding, out_h, out_b), - pad_value, right_w_padding * output_depth * sizeof(T)); + pad_value, right_w_padding * output_depth); } } if (right_h_padding != 0) { - memset(output_data + Offset(output_dims, 0, 0, - output_height - right_h_padding, out_b), - pad_value, - right_h_padding * output_width * output_depth * sizeof(T)); + TypedMemset<T>( + output_data + + Offset(output_dims, 0, 0, output_height - right_h_padding, out_b), + pad_value, right_h_padding * output_width * output_depth); } } if (right_b_padding != 0) { - memset(output_data + - Offset(output_dims, 0, 0, 0, output_batch - right_b_padding), - 0, - right_b_padding * output_height * output_width * output_depth * - sizeof(T)); + TypedMemset<T>( + output_data + + Offset(output_dims, 0, 0, 0, output_batch - right_b_padding), + pad_value, + right_b_padding * output_height * output_width * output_depth); } } +// Legacy Pad() method that casts an int32_t to T before padding. +template <typename T> +inline void Pad(const T* input_data, const Dims<4>& input_dims, + const std::vector<int>& left_paddings, + const std::vector<int>& right_paddings, T* output_data, + const Dims<4>& output_dims, const int32_t pad_value) { + const T converted_pad_value = static_cast<T>(pad_value); + PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data, + output_dims, converted_pad_value); +} + template <typename T> inline void Pad(const T* input_data, const Dims<4>& input_dims, const std::vector<int>& left_paddings, @@ -6279,6 +6318,59 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } } +// UNOPTIMIZED COPY of Select from reference_ops.h. +template <typename D, typename T> +inline void Select(const D* input_condition_data, + const Dims<4>& input_condition_dims, const T* input_x_data, + const Dims<4>& input_x_dims, const T* input_y_data, + const Dims<4>& input_y_dims, T* output_data, + const Dims<4>& output_dims) { + const int64_t batches = + MatchingArraySize(input_condition_dims, 3, input_x_dims, 3, input_y_dims, + 3, output_dims, 3); + const int64_t height = + MatchingArraySize(input_condition_dims, 2, input_x_dims, 2, input_y_dims, + 2, output_dims, 2); + const int64_t width = MatchingArraySize(input_condition_dims, 1, input_x_dims, + 1, input_y_dims, 1, output_dims, 1); + const int64_t depth = MatchingArraySize(input_condition_dims, 0, input_x_dims, + 0, input_y_dims, 0, output_dims, 0); + + const int64_t num_elements = batches * height * width * depth; + for (int64_t i = 0; i < num_elements; ++i) { + output_data[i] = + input_condition_data[i] ? input_x_data[i] : input_y_data[i]; + } +} + +// UNOPTIMIZED COPY of RankOneSelect from reference_ops.h. 
+template <typename D, typename T>
+inline void RankOneSelect(const D* input_condition_data,
+                          const Dims<4>& input_condition_dims,
+                          const T* input_x_data, const Dims<4>& input_x_dims,
+                          const T* input_y_data, const Dims<4>& input_y_dims,
+                          T* output_data, const Dims<4>& output_dims) {
+  const int64_t rank = ArraySize(input_condition_dims, 0);
+
+  const int64_t batches =
+      MatchingArraySize(input_x_dims, 3, input_y_dims, 3, output_dims, 3);
+  const int64_t height =
+      MatchingArraySize(input_x_dims, 2, input_y_dims, 2, output_dims, 2);
+  const int64_t width =
+      MatchingArraySize(input_x_dims, 1, input_y_dims, 1, output_dims, 1);
+  const int64_t depth =
+      MatchingArraySize(input_x_dims, 0, input_y_dims, 0, output_dims, 0);
+
+  TFLITE_DCHECK_EQ(rank, batches);
+
+  int64_t offset = 0;
+  int64_t size = depth * height * width;
+  for (int64_t i = 0; i < rank; i++) {
+    const T* input_data =
+        input_condition_data[i] ? input_x_data : input_y_data;
+    memcpy(output_data + offset, input_data + offset, size * sizeof(T));
+    offset += size;
+  }
+}
+
 }  // namespace optimized_ops
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
index 3e9a3c29ee..2d74b3d384 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -167,6 +167,7 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMinBoundary) {
   EXPECT_EQ(qp.zero_point, 0);
 }
 
+#ifdef GTEST_HAS_DEATH_TEST
 TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroNotInRange) {
   // Assumption is that zero is within the range.
   EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, 30.0), "");
@@ -176,6 +177,7 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangePositive) {
   // Assumption is that zero is within the range.
   EXPECT_DEATH(ChooseQuantizationParams<uint8>(30.0, 30.0), "");
 }
+#endif  // GTEST_HAS_DEATH_TEST
 
 TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangeZero) {
   QuantizationParams qp = ChooseQuantizationParams<uint8>(0.0, 0.0);
@@ -189,6 +191,7 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) {
   EXPECT_EQ(qp.zero_point, 255);
 }
 
+#ifdef GTEST_HAS_DEATH_TEST
 TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
   EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), "");
 }
@@ -261,6 +264,7 @@ TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) {
   EXPECT_THAT(quantize(2.0, 16.0, 5), Pair(2147483647, 31));
   EXPECT_THAT(quantize(2.0, 8.0, 5), Pair(1073741824, 31));
 }
+#endif  // GTEST_HAS_DEATH_TEST
 
 TEST(QuantizationUtilTest, CalculateInputRadius) {
   EXPECT_EQ(CalculateInputRadius(4, 27), 15);
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 798b55abc7..e2978cfd67 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -35,35 +35,6 @@ limitations under the License.
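The rank-one path above (like its reference twin further down) selects whole outer slabs with a single memcpy per batch rather than testing per element, advancing the offset by one slab per iteration. A minimal standalone sketch of that behavior with plain arrays (hypothetical shapes; none of the Dims<4> plumbing):

#include <cassert>
#include <cstring>

int main() {
  // Two batches of three elements each; one condition flag per batch.
  const bool cond[2] = {false, true};
  const int x[6] = {1, 2, 3, 4, 5, 6};
  const int y[6] = {7, 8, 9, 10, 11, 12};
  int out[6];
  const int size = 3;  // elements per batch slab (depth * height * width)
  for (int i = 0, offset = 0; i < 2; ++i, offset += size) {
    const int* src = cond[i] ? x : y;
    std::memcpy(out + offset, src + offset, size * sizeof(int));
  }
  assert(out[0] == 7 && out[3] == 4);  // batch 0 from y, batch 1 from x
  return 0;
}

The reference_ops.h hunk that follows starts by deleting its private copies of MultiplyByQuantizedMultiplierSmallerThanOne, MultiplyByQuantizedMultiplierGreaterThanOne, and CountLeadingZeros, which both backends presumably now pick up from common.h (CountLeadingZeros was added there in this same change).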
namespace tflite { namespace reference_ops { -inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( - int32 x, int32 quantized_multiplier, int right_shift) { - using gemmlowp::RoundingDivideByPOT; - using gemmlowp::SaturatingRoundingDoublingHighMul; - return RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); -} - -inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( - int32 x, int32 quantized_multiplier, int left_shift) { - using gemmlowp::SaturatingRoundingDoublingHighMul; - return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), - quantized_multiplier); -} - -template <typename T> -int CountLeadingZeros(T integer_input) { - static_assert(std::is_unsigned<T>::value, - "Only unsigned integer types handled."); - const T one_in_leading_positive = static_cast<T>(1) - << (std::numeric_limits<T>::digits - 1); - int leading_zeros = 0; - while (integer_input < one_in_leading_positive) { - integer_input <<= 1; - ++leading_zeros; - } - return leading_zeros; -} - // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE // BROADCASTING. // @@ -3158,10 +3129,10 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, } template <typename T> -inline void Pad(const T* input_data, const Dims<4>& input_dims, - const std::vector<int>& left_paddings, - const std::vector<int>& right_paddings, T* output_data, - const Dims<4>& output_dims, const int32_t pad_value) { +inline void PadV2(const T* input_data, const Dims<4>& input_dims, + const std::vector<int>& left_paddings, + const std::vector<int>& right_paddings, T* output_data, + const Dims<4>& output_dims, const T pad_value) { TFLITE_DCHECK_EQ(left_paddings.size(), 4); TFLITE_DCHECK_EQ(right_paddings.size(), 4); @@ -3194,7 +3165,7 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, out_w >= output_width - right_w_padding || out_d < left_d_padding || out_d >= output_depth - right_d_padding) { - *out_ptr++ = static_cast<T>(pad_value); + *out_ptr++ = pad_value; } else { *out_ptr++ = *in_ptr++; } @@ -3204,6 +3175,17 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, } } +// Legacy Pad() method that casts an int32_t to T before padding. 
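The comment above flags the key difference: the legacy entry point takes its pad value as int32_t regardless of T, so a fractional float pad was never expressible; the value was forced integral before static_cast<T>() ever ran. A hypothetical standalone sketch of that limitation (LegacyPadValue is illustration-only, not part of the patch):

#include <cassert>
#include <cstdint>

// The legacy signature funnels every pad value through int32_t.
template <typename T>
T LegacyPadValue(int32_t v) {
  return static_cast<T>(v);
}

int main() {
  assert(LegacyPadValue<float>(0) == 0.0f);  // whole numbers only
  // No int32_t argument yields 0.5f here; PadV2's `const T pad_value`
  // parameter carries such values through unchanged.
  return 0;
}

The shim that follows performs exactly that cast and forwards to PadV2, keeping old call sites working.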
+template <typename T> +inline void Pad(const T* input_data, const Dims<4>& input_dims, + const std::vector<int>& left_paddings, + const std::vector<int>& right_paddings, T* output_data, + const Dims<4>& output_dims, const int32_t pad_value) { + const T converted_pad_value = static_cast<T>(pad_value); + PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data, + output_dims, converted_pad_value); +} + template <typename T> inline void Pad(const T* input_data, const Dims<4>& input_dims, const std::vector<int>& left_paddings, @@ -3603,17 +3585,29 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } template <typename T> -inline void Less(int64_t num_elements, const T* input1, const T* input2, - bool* output) { - for (int64_t i = 0; i < num_elements; ++i) { - output[i] = input1[i] < input2[i]; - } +inline bool GreaterFn(T lhs, T rhs) { + return lhs > rhs; +} +template <typename T> +inline bool GreaterEqualFn(T lhs, T rhs) { + return lhs >= rhs; +} +template <typename T> +inline bool LessFn(T lhs, T rhs) { + return lhs < rhs; +} +template <typename T> +inline bool LessEqualFn(T lhs, T rhs) { + return lhs <= rhs; } template <typename T> -inline void Less(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - bool* output_data, const Dims<4>& output_dims) { +using ComparisonFn = bool (*)(T, T); + +template <typename T, ComparisonFn<T> F> +inline void Comparison(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + bool* output_data, const Dims<4>& output_dims) { const int64_t batches = MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); const int64_t height = @@ -3622,31 +3616,201 @@ inline void Less(const T* input1_data, const Dims<4>& input1_dims, MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); const int64_t depth = MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); - Less(batches * height * width * depth, input1_data, input2_data, output_data); + for (int64_t i = 0; i < batches * height * width * depth; ++i) { + output_data[i] = F(input1_data[i], input2_data[i]); + } } -template <typename T1, typename T2> -inline void BroadcastLess(T1* input1_data, const Dims<4>& input1_dims, - T2* input2_data, const Dims<4>& input2_dims, - bool* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastLess"); +template <typename T, ComparisonFn<T> F> +inline void Comparison(int left_shift, const T* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const T* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, bool* output_data, + const Dims<4>& output_dims) { + const int64_t batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int64_t height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int64_t width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int64_t depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int64_t i = 0; i < batches * height * width * depth; ++i) { + const int32 input1_val = input1_offset + input1_data[i]; + const int32 input2_val = input2_offset + input2_data[i]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 
scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + output_data[i] = F(scaled_input1_val, scaled_input2_val); + } +} + +template <typename T, ComparisonFn<T> F> +inline void BroadcastComparison(const T* input1_data, + const Dims<4>& input1_dims, + const T* input2_data, + const Dims<4>& input2_dims, bool* output_data, + const Dims<4>& output_dims) { NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + F(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} +template <typename T, ComparisonFn<T> F> +inline void BroadcastComparison(int left_shift, const T* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const T* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 input2_multiplier, int input2_shift, + bool* output_data, const Dims<4>& output_dims) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); for (int b = 0; b < ArraySize(output_dims, 3); ++b) { for (int y = 0; y < ArraySize(output_dims, 2); ++y) { for (int x = 0; x < ArraySize(output_dims, 1); ++x) { for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); output_data[Offset(output_dims, c, x, y, b)] = - input1_data[SubscriptToIndex(desc1, c, x, y, b)] < - input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + F(scaled_input1_val, scaled_input2_val); } } } } } +#define TFLITE_COMPARISON_OP(name) \ + template <typename T> \ + inline void name(const T* input1_data, const Dims<4>& input1_dims, \ + const T* input2_data, const Dims<4>& input2_dims, \ + bool* output_data, const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label(#name); \ + Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \ + input2_dims, output_data, output_dims); \ + } \ + template <typename T> \ + inline void name( \ + int left_shift, const T* input1_data, const Dims<4>& input1_dims, \ + int32 input1_offset, int32 input1_multiplier, int input1_shift, \ + const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \ + int32 input2_multiplier, int input2_shift, bool* output_data, \ + const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \ + BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \ + input1_offset, input1_multiplier, \ + 
input1_shift, input2_data, input2_dims, \ + input2_offset, input2_multiplier, \ + input2_shift, output_data, output_dims); \ + } \ + template <typename T> \ + inline void Broadcast##name( \ + const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \ + const Dims<4>& input2_dims, bool* output_data, \ + const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \ + BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \ + input2_dims, output_data, output_dims); \ + } \ + template <typename T> \ + inline void Broadcast##name( \ + int left_shift, const T* input1_data, const Dims<4>& input1_dims, \ + int32 input1_offset, int32 input1_multiplier, int input1_shift, \ + const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \ + int32 input2_multiplier, int input2_shift, bool* output_data, \ + const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \ + BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \ + input1_offset, input1_multiplier, \ + input1_shift, input2_data, input2_dims, \ + input2_offset, input2_multiplier, \ + input2_shift, output_data, output_dims); \ + } +TFLITE_COMPARISON_OP(Greater); +TFLITE_COMPARISON_OP(GreaterEqual); +TFLITE_COMPARISON_OP(Less); +TFLITE_COMPARISON_OP(LessEqual); +#undef TFLITE_COMPARISON_OP + +template <typename D, typename T> +inline void Select(const D* input_condition_data, + const Dims<4>& input_condition_dims, const T* input_x_data, + const Dims<4>& input_x_dims, const T* input_y_data, + const Dims<4>& input_y_dims, T* output_data, + const Dims<4>& output_dims) { + const int64_t batches = + MatchingArraySize(input_condition_dims, 3, input_x_dims, 3, input_y_dims, + 3, output_dims, 3); + const int64_t height = + MatchingArraySize(input_condition_dims, 2, input_x_dims, 2, input_y_dims, + 2, output_dims, 2); + const int64_t width = MatchingArraySize(input_condition_dims, 1, input_x_dims, + 1, input_y_dims, 1, output_dims, 1); + const int64_t depth = MatchingArraySize(input_condition_dims, 0, input_x_dims, + 0, input_y_dims, 0, output_dims, 0); + + const int64_t num_elements = batches * height * width * depth; + for (int64_t i = 0; i < num_elements; ++i) { + output_data[i] = + input_condition_data[i] ? input_x_data[i] : input_y_data[i]; + } +} + +template <typename D, typename T> +inline void RankOneSelect(const D* input_condition_data, + const Dims<4>& input_condition_dims, + const T* input_x_data, const Dims<4>& input_x_dims, + const T* input_y_data, const Dims<4>& input_y_dims, + T* output_data, const Dims<4>& output_dims) { + const int64_t rank = ArraySize(input_condition_dims, 0); + + const int64_t batches = + MatchingArraySize(input_x_dims, 3, input_y_dims, 3, output_dims, 3); + const int64_t height = + MatchingArraySize(input_x_dims, 2, input_y_dims, 2, output_dims, 2); + const int64_t width = + MatchingArraySize(input_x_dims, 1, input_y_dims, 1, output_dims, 1); + const int64_t depth = + MatchingArraySize(input_x_dims, 0, input_y_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(rank, batches); + + int64_t offset = 0; + int64_t size = depth * height * width; + for (int64_t i = 0; i < rank; i++) { + const T* input_data = input_condition_data[i] ? 
input_x_data : input_y_data; + memcpy(output_data + offset, input_data + offset, size * sizeof(T)); + offset += size; + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 4f9449a225..9e1e4658e9 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -37,9 +37,15 @@ struct PadContext { PadContext(TfLiteContext* context, TfLiteNode* node) { input = GetInput(context, node, 0); paddings = GetInput(context, node, 1); + if (NumInputs(node) == 3) { + constant_values = GetOptionalInputTensor(context, node, 2); + } else { + constant_values = nullptr; + } output = GetOutput(context, node, 0); dims = NumDimensions(input); } + TfLiteTensor* constant_values; TfLiteTensor* input; TfLiteTensor* paddings; TfLiteTensor* output; @@ -76,11 +82,15 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); PadContext op_context(context, node); TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + if (op_context.constant_values != nullptr) { + TF_LITE_ENSURE_EQ(context, op_context.input->type, + op_context.constant_values->type); + } // TODO(nupurgarg): Our current implementations rely on the inputs being 4D. TF_LITE_ENSURE_EQ(context, op_context.dims, 4); @@ -98,6 +108,11 @@ template <KernelType kernel_type> TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { PadContext op_context(context, node); + if (op_context.constant_values != nullptr) { + // Ensure that constant_values is a scalar. + TF_LITE_ENSURE_EQ(context, NumElements(op_context.constant_values), 1); + } + // Resize the output tensor if the output tensor is dynamic. if (IsDynamicTensor(op_context.output)) { TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); @@ -119,48 +134,70 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { after_padding.push_back(paddings_data[idx * 2 + 1]); } -#define TF_LITE_PAD(type, scalar, pad_value) \ - type::Pad(GetTensorData<scalar>(op_context.input), \ - GetTensorDims(op_context.input), before_padding, after_padding, \ - GetTensorData<scalar>(op_context.output), \ - GetTensorDims(op_context.output), pad_value) +#define TF_LITE_PAD(type, scalar, pad_value) \ + type::PadV2(GetTensorData<scalar>(op_context.input), \ + GetTensorDims(op_context.input), before_padding, after_padding, \ + GetTensorData<scalar>(op_context.output), \ + GetTensorDims(op_context.output), pad_value) switch (op_context.input->type) { - case kTfLiteFloat32: + case kTfLiteFloat32: { + float pad_value = op_context.constant_values == nullptr + ? 0.f + : *GetTensorData<float>(op_context.constant_values); if (kernel_type == kReference) { - TF_LITE_PAD(reference_ops, float, 0); + TF_LITE_PAD(reference_ops, float, pad_value); } else if (kernel_type == kGenericOptimized) { - TF_LITE_PAD(optimized_ops, float, 0); + TF_LITE_PAD(optimized_ops, float, pad_value); + } + } break; + case kTfLiteUInt8: { + uint8_t pad_value; + if (op_context.constant_values == nullptr) { + // Quantized Pad requires that 0 is represented in the quantized + // range. 
+ TF_LITE_ENSURE(context, op_context.output->params.zero_point >= + std::numeric_limits<uint8_t>::min()); + TF_LITE_ENSURE(context, op_context.output->params.zero_point <= + std::numeric_limits<uint8_t>::max()); + pad_value = static_cast<uint8_t>(op_context.output->params.zero_point); + } else { + // Quantized Pad requires that 'constant_values' is represented in the + // same quantized range as the input and output tensors. + TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, + op_context.constant_values->params.zero_point); + TF_LITE_ENSURE_EQ(context, op_context.output->params.scale, + op_context.constant_values->params.scale); + pad_value = *GetTensorData<uint8_t>(op_context.constant_values); } - break; - case kTfLiteUInt8: - // Quantized Pad requires that 0 is represented in the quantized range. - TF_LITE_ENSURE(context, op_context.output->params.zero_point >= - std::numeric_limits<uint8_t>::min()); - TF_LITE_ENSURE(context, op_context.output->params.zero_point <= - std::numeric_limits<uint8_t>::max()); if (kernel_type == kReference) { - TF_LITE_PAD(reference_ops, uint8_t, - op_context.output->params.zero_point); + TF_LITE_PAD(reference_ops, uint8_t, pad_value); } else if (kernel_type == kGenericOptimized) { - TF_LITE_PAD(optimized_ops, uint8_t, - op_context.output->params.zero_point); + TF_LITE_PAD(optimized_ops, uint8_t, pad_value); } - break; - case kTfLiteInt32: + } break; + case kTfLiteInt32: { + int32_t pad_value = + op_context.constant_values == nullptr + ? 0 + : *GetTensorData<int32_t>(op_context.constant_values); if (kernel_type == kReference) { - TF_LITE_PAD(reference_ops, int32_t, 0); + TF_LITE_PAD(reference_ops, int32_t, pad_value); } else if (kernel_type == kGenericOptimized) { - TF_LITE_PAD(optimized_ops, int32_t, 0); + TF_LITE_PAD(optimized_ops, int32_t, pad_value); } - break; - case kTfLiteInt64: + } break; + case kTfLiteInt64: { + int64_t pad_value = + op_context.constant_values == nullptr + ? 0L + : *GetTensorData<int64_t>(op_context.constant_values); if (kernel_type == kReference) { - TF_LITE_PAD(reference_ops, int64_t, 0); + TF_LITE_PAD(reference_ops, int64_t, pad_value); } else if (kernel_type == kGenericOptimized) { - TF_LITE_PAD(optimized_ops, int64_t, 0); + TF_LITE_PAD(optimized_ops, int64_t, pad_value); } - break; + } break; default: context->ReportError(context, "Type is currently not supported by Pad."); return kTfLiteError; @@ -185,6 +222,21 @@ TfLiteRegistration* Register_PAD_GENERIC_OPT() { TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); } +// Also register Pad as PadV2. 
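For the quantized path just above, both branches reduce to picking a uint8 code for the pad: the two-input form pads with output->params.zero_point (the code for real 0.0), while the three-input form insists that constant_values already shares the output's scale and zero_point so its single element can be used verbatim. A standalone sketch of that affine mapping with hypothetical numbers (QuantizePadValue is illustration-only; scale 2/255 and zero_point 128 correspond to a [-1, 1] range like the quantized tests further down):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper: affine-quantize a real pad value into the uint8
// coding assumed to be shared by constant_values and the output tensor.
uint8_t QuantizePadValue(float value, float scale, int32_t zero_point) {
  const int32_t q =
      zero_point + static_cast<int32_t>(std::round(value / scale));
  return static_cast<uint8_t>(std::min(255, std::max(0, q)));
}

int main() {
  const float scale = 2.0f / 255.0f;
  const int32_t zero_point = 128;
  assert(QuantizePadValue(0.0f, scale, zero_point) == 128);  // == zero_point
  assert(QuantizePadValue(-0.5f, scale, zero_point) == 64);
  return 0;
}

The registrations below then expose the same Prepare/Eval pair under the new PADV2 builtin.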
+TfLiteRegistration* Register_PADV2_REF() { + static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, + pad::Eval<pad::kReference>}; + return &r; +} + +TfLiteRegistration* Register_PADV2_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, + pad::Eval<pad::kGenericOptimized>}; + return &r; +} + +TfLiteRegistration* Register_PADV2() { return Register_PADV2_GENERIC_OPT(); } + } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc index c06237e572..f8b9064fbb 100644 --- a/tensorflow/contrib/lite/kernels/pad_test.cc +++ b/tensorflow/contrib/lite/kernels/pad_test.cc @@ -24,21 +24,26 @@ namespace { using ::testing::ElementsAreArray; using ::testing::Matcher; +template <typename T> class PadOpModel : public SingleOpModel { public: - void SetInput(std::initializer_list<float> data) { - PopulateTensor<float>(input_, data); + void SetInput(std::initializer_list<T> data) { + PopulateTensor<T>(input_, data); } void SetQuantizedInput(std::initializer_list<float> data) { QuantizeAndPopulate<uint8_t>(input_, data); } + void SetQuantizedPadValue(float data) { + QuantizeAndPopulate<uint8_t>(constant_values_, {data}); + } + void SetPaddings(std::initializer_list<int> paddings) { PopulateTensor<int>(paddings_, paddings); } - std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + std::vector<T> GetOutput() { return ExtractVector<T>(output_); } std::vector<int> GetOutputShape() { return GetTensorShape(output_); } std::vector<float> GetDequantizedOutput() { @@ -50,6 +55,59 @@ class PadOpModel : public SingleOpModel { int input_; int output_; int paddings_; + int constant_values_; +}; + +namespace { + +// Returns the corresponding TensorType given the type T. +template <typename T> +TensorType GetTensorType() { + if (std::is_same<T, float>::value) return TensorType_FLOAT32; + if (std::is_same<T, int32_t>::value) return TensorType_INT32; + if (std::is_same<T, uint8_t>::value) return TensorType_UINT8; + return TensorType_MIN; // default value +} + +} // namespace + +// Tests case where paddings is a const tensor. Type T is the dtype. +template <typename T> +class PadV2OpConstModel : public PadOpModel<T> { + public: + PadV2OpConstModel(const TensorData& input, + std::initializer_list<int> paddings_shape, + std::initializer_list<int> paddings, T constant_values, + const TensorData& output) { + this->input_ = this->AddInput(input); + this->paddings_ = + this->AddConstInput(TensorType_INT32, paddings, paddings_shape); + this->constant_values_ = + this->AddConstInput(GetTensorType<T>(), {constant_values}, {1}); + + this->output_ = this->AddOutput(output); + + this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options, + CreatePadV2Options(this->builder_).Union()); + this->BuildInterpreter({input.shape}); + } + + PadV2OpConstModel(const TensorData& input, + std::initializer_list<int> paddings_shape, + std::initializer_list<int> paddings, + const TensorData& constant_values, + const TensorData& output) { + this->input_ = this->AddInput(input); + this->paddings_ = + this->AddConstInput(TensorType_INT32, paddings, paddings_shape); + this->constant_values_ = this->AddInput(constant_values); + + this->output_ = this->AddOutput(output); + + this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options, + CreatePadV2Options(this->builder_).Union()); + this->BuildInterpreter({input.shape}); + } }; // Tests case where paddings is a const tensor. 
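With the fixture now templated on the element type, the paddings layout is the one thing every test below repeats: a {4, 2} tensor flattened row-major into {before, after} pairs, one pair per input dimension, outermost first. A standalone sketch of how it determines the output shape (PaddedShape is a hypothetical helper, not part of the patch):

#include <array>
#include <cassert>

std::array<int, 4> PaddedShape(const std::array<int, 4>& in,
                               const std::array<int, 8>& paddings) {
  std::array<int, 4> out;
  for (int d = 0; d < 4; ++d) {
    out[d] = in[d] + paddings[2 * d] + paddings[2 * d + 1];
  }
  return out;
}

int main() {
  // Matches SimpleConstTest below: input {1, 2, 2, 1} with paddings
  // {0, 0, 1, 1, 1, 1, 0, 0} yields output shape {1, 4, 4, 1}.
  assert((PaddedShape({1, 2, 2, 1}, {0, 0, 1, 1, 1, 1, 0, 0}) ==
          std::array<int, 4>{1, 4, 4, 1}));
  return 0;
}

The next hunk retrofits the existing float-only PadOpConstModel onto this templated base.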
@@ -58,7 +116,7 @@ class PadOpModel : public SingleOpModel { // PadOpDynamicModel m(input_shape, paddings_shape, paddings_data); // m.SetInput(input_data); // m.Invoke(); -class PadOpConstModel : public PadOpModel { +class PadOpConstModel : public PadOpModel<float> { public: PadOpConstModel(const TensorData& input, std::initializer_list<int> paddings_shape, @@ -66,6 +124,7 @@ class PadOpConstModel : public PadOpModel { const TensorData& output) { input_ = AddInput(input); paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape); + constant_values_ = AddNullInput(); output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions, @@ -75,19 +134,52 @@ class PadOpConstModel : public PadOpModel { }; // Test case where paddings is a non-const tensor. +template <typename T> +class PadV2OpDynamicModel : public PadOpModel<T> { + public: + PadV2OpDynamicModel(const TensorData& input, + std::initializer_list<int> paddings_shape, + T constant_values, const TensorData& output) { + this->input_ = this->AddInput(input); + this->paddings_ = this->AddInput(TensorType_INT32); + this->constant_values_ = + this->AddConstInput(GetTensorType<T>(), {constant_values}, {1}); + this->output_ = this->AddOutput(output); + + this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options, + CreatePadV2Options(this->builder_).Union()); + this->BuildInterpreter({input.shape, paddings_shape}); + } + PadV2OpDynamicModel(const TensorData& input, + std::initializer_list<int> paddings_shape, + const TensorData& constant_values, + const TensorData& output) { + this->input_ = this->AddInput(input); + this->paddings_ = this->AddInput(TensorType_INT32); + this->constant_values_ = this->AddInput(constant_values); + this->output_ = this->AddOutput(output); + + this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options, + CreatePadV2Options(this->builder_).Union()); + this->BuildInterpreter({input.shape, paddings_shape}); + } +}; + +// Test case where paddings is a non-const tensor. 
// // Example usage is as follows: // PadOpDynamicModel m(input_shape, paddings_shape); // m.SetInput(input_data); // m.SetPaddings(paddings_data); // m.Invoke(); -class PadOpDynamicModel : public PadOpModel { +class PadOpDynamicModel : public PadOpModel<float> { public: PadOpDynamicModel(const TensorData& input, std::initializer_list<int> paddings_shape, const TensorData& output) { input_ = AddInput(input); paddings_ = AddInput(TensorType_INT32); + constant_values_ = AddNullInput(); output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions, @@ -237,6 +329,272 @@ TEST_F(QuantizedPadOpTest, AdvancedDynamicTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); } +TEST(PadV2OpTest, TooManyDimensions) { + EXPECT_DEATH(PadV2OpConstModel<float>( + {TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2}, + {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, 0.0, + {TensorType_FLOAT32}), + "dims != 4"); +} + +TEST(PadV2OpTest, UnequalDimensions) { + EXPECT_DEATH( + PadV2OpConstModel<float>({TensorType_FLOAT32, {1, 1, 2, 1}}, {3, 2}, + {1, 1, 2, 2, 3, 3}, 0.0, {TensorType_FLOAT32}), + "3 != 4"); +} + +TEST(PadV2OpTest, InvalidPadValue) { + EXPECT_DEATH(PadV2OpConstModel<float>({TensorType_FLOAT32, {1, 1, 2, 1}}, + {4, 2}, {0, 0, 1, -1, 2, -1, 0, 0}, 0.0, + {TensorType_FLOAT32}), + "Pad value has to be greater than equal to 0."); +} + +TEST(PadV2OpTest, SimpleConstTest) { + // Padding is represented as four 2-D lists representing above padding and + // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). + PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, + {0, 0, 1, 1, 1, 1, 0, 0}, 0.0, + {TensorType_FLOAT32}); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, + 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST(PadV2OpTest, SimpleConstFloat32ValuedTest) { + // Padding is represented as four 2-D lists representing above padding and + // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). + PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, + {0, 0, 1, 1, 1, 1, 0, 0}, 5, {TensorType_FLOAT32}); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 1, 2, 5, 5, 3, 4, + 5, 5, 5, 5, 5})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST(PadV2OpTest, Simple4DConstFloat32ValuedTest) { + // Padding is represented as four 2-D lists representing above padding and + // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). + PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 1, 2, 1}}, {4, 2}, + {0, 1, 0, 0, 0, 0, 0, 1}, 5, {TensorType_FLOAT32}); + m.SetInput({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 5, 3, 5, 5, 5, 5, 5})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 2})); +} + +TEST(PadV2OpTest, SimpleConstInt32ValuedTest) { + // Padding is represented as four 2-D lists representing above padding and + // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). 
+ PadV2OpConstModel<int32_t> m({TensorType_INT32, {1, 2, 2, 1}}, {4, 2}, + {0, 0, 1, 1, 1, 1, 0, 0}, 5, {TensorType_INT32}); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 1, 2, 5, 5, 3, 4, + 5, 5, 5, 5, 5})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST(PadV2OpTest, SimpleDynamicTest) { + PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, 0.0, + {TensorType_FLOAT32}); + m.SetInput({1, 2, 3, 4}); + m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, + 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST(PadV2OpTest, SimpleDynamicValuedTest) { + PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, 5, + {TensorType_FLOAT32}); + m.SetInput({1, 2, 3, 4}); + m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 1, 2, 5, 5, 3, 4, + 5, 5, 5, 5, 5})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST(PadV2OpTest, AdvancedConstTest) { + PadV2OpConstModel<float> m({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2}, + {0, 0, 0, 2, 1, 3, 0, 0}, 0, {TensorType_FLOAT32}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +TEST(PadV2OpTest, AdvancedDynamicTest) { + PadV2OpDynamicModel<float> m({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2}, 0, + {TensorType_FLOAT32}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +class QuantizedPadV2OpTest : public ::testing::Test { + protected: + std::vector<Matcher<float>> DequantizedArrayNear( + const std::vector<float>& values, const float min, const float max) { + const float quantization_tolerance = (max - min) / 255.0; + return ArrayFloatNear(values, quantization_tolerance); + } +}; + +TEST_F(QuantizedPadV2OpTest, ZeroNotInQuantizationRange) { + // The test_util and actual quantization code currently ensure that the range + // must include zero, but if that ever changes, this test will catch it. + EXPECT_DEATH( + PadV2OpConstModel<float> m({TensorType_UINT8, {1, 2, 2, 1}, 1.0, 2.0}, + {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0}, 0, + {TensorType_UINT8, {}, 1.0, 2.0}), + ".*Check failed: f_min <= 0.*"); +} + +TEST_F(QuantizedPadV2OpTest, SimpleConstTest) { + // Padding is represented as four 2-D lists representing above padding and + // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). 
+ PadV2OpConstModel<uint8_t> m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0}, + {TensorType_UINT8, {1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); + m.SetQuantizedPadValue(0); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0}, + -1.0, 1.0))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST_F(QuantizedPadV2OpTest, SimpleDynamicTest) { + PadV2OpDynamicModel<uint8_t> m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {4, 2}, {TensorType_UINT8, {1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); + m.SetQuantizedPadValue(0); + m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0}, + -1.0, 1.0))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST_F(QuantizedPadV2OpTest, AdvancedConstTest) { + PadV2OpConstModel<uint8_t> m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0}, + {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0}, + {TensorType_UINT8, {1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); + m.SetQuantizedPadValue(0); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + -1.0, 1.0))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +TEST_F(QuantizedPadV2OpTest, AdvancedDynamicTest) { + PadV2OpDynamicModel<uint8_t> m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0}, + {4, 2}, {TensorType_UINT8, {1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); + m.SetQuantizedPadValue(0); + m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + -1.0, 1.0))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +TEST_F(QuantizedPadV2OpTest, SimpleConstValuedTest) { + // Padding is represented as four 2-D lists representing above padding and + // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). 
+ PadV2OpConstModel<uint8_t> m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0}, + {TensorType_UINT8, {1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); + m.SetQuantizedPadValue(-0.5); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {-0.5, -0.5, -0.5, -0.5, -0.5, -0.8, 0.2, -0.5, -0.5, 0.9, + 0.7, -0.5, -0.5, -0.5, -0.5, -0.5}, + -1.0, 1.0))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST_F(QuantizedPadV2OpTest, SimpleDynamicValuedTest) { + PadV2OpDynamicModel<uint8_t> m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {4, 2}, {TensorType_UINT8, {1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); + m.SetQuantizedPadValue(-0.5); + m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {-0.5, -0.5, -0.5, -0.5, -0.5, -0.8, 0.2, -0.5, -0.5, 0.9, + 0.7, -0.5, -0.5, -0.5, -0.5, -0.5}, + -1.0, 1.0))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST_F(QuantizedPadV2OpTest, AdvancedConstValuedTest) { + PadV2OpConstModel<uint8_t> m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0}, + {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0}, + {TensorType_UINT8, {1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); + m.SetQuantizedPadValue(-0.5); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {-0.5, -0.8, 0.2, 0.9, -0.5, -0.5, -0.5, -0.5, 0.7, 0.1, + -0.3, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, + -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5}, + -1.0, 1.0))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +TEST_F(QuantizedPadV2OpTest, AdvancedDynamicValuedTest) { + PadV2OpDynamicModel<uint8_t> m({TensorType_UINT8, {1, 2, 3, 1}, -1.0, 1.0}, + {4, 2}, {TensorType_UINT8, {1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); + m.SetQuantizedPadValue(-0.5); + m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {-0.5, -0.8, 0.2, 0.9, -0.5, -0.5, -0.5, -0.5, 0.7, 0.1, + -0.3, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, + -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5}, + -1.0, 1.0))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 29ea718a96..5df35aac62 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -60,6 +60,7 @@ TfLiteRegistration* Register_LSTM(); TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM(); TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM(); TfLiteRegistration* Register_PAD(); +TfLiteRegistration* Register_PADV2(); TfLiteRegistration* Register_RESHAPE(); TfLiteRegistration* Register_RESIZE_BILINEAR(); TfLiteRegistration* Register_SKIP_GRAM(); @@ -79,9 +80,13 @@ TfLiteRegistration* Register_PRELU(); TfLiteRegistration* Register_MAXIMUM(); TfLiteRegistration* Register_MINIMUM(); TfLiteRegistration* Register_ARG_MAX(); +TfLiteRegistration* Register_GREATER(); +TfLiteRegistration* Register_GREATER_EQUAL(); TfLiteRegistration* Register_LESS(); +TfLiteRegistration* 
Register_LESS_EQUAL(); TfLiteRegistration* Register_FLOOR(); TfLiteRegistration* Register_NEG(); +TfLiteRegistration* Register_SELECT(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -121,6 +126,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, Register_UNIDIRECTIONAL_SEQUENCE_LSTM()); AddBuiltin(BuiltinOperator_PAD, Register_PAD()); + AddBuiltin(BuiltinOperator_PADV2, Register_PADV2()); AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR()); AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); @@ -142,9 +148,13 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM()); AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM()); AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX()); + AddBuiltin(BuiltinOperator_GREATER, Register_GREATER()); + AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL()); AddBuiltin(BuiltinOperator_LESS, Register_LESS()); + AddBuiltin(BuiltinOperator_LESS_EQUAL, Register_LESS_EQUAL()); AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR()); AddBuiltin(BuiltinOperator_NEG, Register_NEG()); + AddBuiltin(BuiltinOperator_SELECT, Register_SELECT()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc new file mode 100644 index 0000000000..029ad9a709 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/select.cc @@ -0,0 +1,125 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace select { + +constexpr int kInputTensorCondition = 0; +constexpr int kInputTensorX = 1; +constexpr int kInputTensorY = 2; +constexpr int kOutputTensor = 0; + +TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input_condition = + GetInput(context, node, kInputTensorCondition); + TfLiteTensor* input_x = GetInput(context, node, kInputTensorX); + TfLiteTensor* input_y = GetInput(context, node, kInputTensorY); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // Input must be bool. 
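Only the condition input is boolean; x and y merely have to agree with each other in type and shape, and the condition either matches them exactly or is a rank-one vector matching their first dimension. As a worked illustration of the element-wise path SelectEval dispatches to, a standalone sketch with plain arrays (it mirrors the expectations of the SelectFloat test in the new test file further down):

#include <cassert>

int main() {
  const bool cond[4] = {true, false, true, false};
  const float x[4] = {0.1f, 0.2f, 0.3f, 0.4f};
  const float y[4] = {0.5f, 0.6f, 0.7f, 0.8f};
  float out[4];
  for (int i = 0; i < 4; ++i) {
    out[i] = cond[i] ? x[i] : y[i];  // per-element pick, same shapes
  }
  assert(out[1] == 0.6f && out[2] == 0.3f);  // {0.1, 0.6, 0.3, 0.8}
  return 0;
}

The checks that follow encode exactly those constraints: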
+ TF_LITE_ENSURE(context, input_condition->type == kTfLiteBool); + + // Input tensors must have the same type and size + TF_LITE_ENSURE_EQ(context, input_x->type, input_y->type); + TF_LITE_ENSURE(context, HaveSameShapes(input_x, input_y)); + output->type = input_x->type; + + // Either the same shape, or input_condition must be Rank 1 and match over the + // first dimension. + bool same_shape = HaveSameShapes(input_condition, input_x); + if (!same_shape && NumDimensions(input_condition) == 1) { + same_shape = + SizeOfDimension(input_condition, 0) == SizeOfDimension(input_x, 0); + } + + TF_LITE_ENSURE(context, same_shape); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_x->dims); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input_condition = + GetInput(context, node, kInputTensorCondition); + TfLiteTensor* input_x = GetInput(context, node, kInputTensorX); + TfLiteTensor* input_y = GetInput(context, node, kInputTensorY); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + bool is_rank_one = !HaveSameShapes(input_condition, input_x); + +#define TF_LITE_SELECT(type, op) \ + reference_ops::op(GetTensorData<bool>(input_condition), \ + GetTensorDims(input_condition), \ + GetTensorData<type>(input_x), GetTensorDims(input_x), \ + GetTensorData<type>(input_y), GetTensorDims(input_y), \ + GetTensorData<type>(output), GetTensorDims(output)); + +#define TF_LITE_SWITCH(type, op) \ + switch (type) { \ + break; \ + case kTfLiteBool: \ + TF_LITE_SELECT(bool, op); \ + break; \ + case kTfLiteFloat32: \ + TF_LITE_SELECT(float, op); \ + break; \ + case kTfLiteUInt8: \ + TF_LITE_SELECT(uint8_t, op); \ + break; \ + case kTfLiteInt32: \ + TF_LITE_SELECT(int32_t, op); \ + break; \ + case kTfLiteInt64: \ + TF_LITE_SELECT(int64_t, op); \ + break; \ + default: \ + context->ReportError(context, \ + "Does not support type other than bool|float|int"); \ + return kTfLiteError; \ + } + + if (is_rank_one) { + TF_LITE_SWITCH(input_x->type, RankOneSelect); + } else { + TF_LITE_SWITCH(input_x->type, Select); + } + +#undef TF_LITE_SELECT +#undef TF_LITE_SWITCH + return kTfLiteOk; +} + +} // namespace select + +TfLiteRegistration* Register_SELECT() { + static TfLiteRegistration r = {nullptr, nullptr, select::SelectPrepare, + select::SelectEval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/select_test.cc b/tensorflow/contrib/lite/kernels/select_test.cc new file mode 100644 index 0000000000..cfe24a5fc9 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/select_test.cc @@ -0,0 +1,143 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class SelectOpModel : public SingleOpModel { + public: + SelectOpModel(std::initializer_list<int> input1_shape, + std::initializer_list<int> input2_shape, + std::initializer_list<int> input3_shape, + TensorType input_type) { + input1_ = AddInput(TensorType_BOOL); + input2_ = AddInput(input_type); + input3_ = AddInput(input_type); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_SELECT, BuiltinOptions_SelectOptions, + CreateSelectOptions(builder_).Union()); + BuildInterpreter({input1_shape, input2_shape, input3_shape}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + int input3() { return input3_; } + + template <typename T> + std::vector<T> GetOutput() { + return ExtractVector<T>(output_); + } + + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int input3_; + int output_; +}; + +TEST(SelectOpTest, SelectBool) { + SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, + TensorType_BOOL); + + model.PopulateTensor<bool>(model.input1(), {true, false, true, false}); + model.PopulateTensor<bool>(model.input2(), {false, false, false, false}); + model.PopulateTensor<bool>(model.input3(), {true, true, true, true}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput<bool>(), + ElementsAreArray({false, true, false, true})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(SelectOpTest, SelectFloat) { + SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, + TensorType_FLOAT32); + + model.PopulateTensor<bool>(model.input1(), {true, false, true, false}); + model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.3, 0.4}); + model.PopulateTensor<float>(model.input3(), {0.5, 0.6, 0.7, 0.8}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray({0.1, 0.6, 0.3, 0.8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(SelectOpTest, SelectUInt8) { + SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, + TensorType_UINT8); + + model.PopulateTensor<bool>(model.input1(), {false, true, false, false}); + model.PopulateTensor<uint8>(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor<uint8>(model.input3(), {5, 6, 7, 8}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput<uint8>(), ElementsAreArray({5, 2, 7, 8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(SelectOpTest, SelectInt32) { + SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, + TensorType_INT32); + + model.PopulateTensor<bool>(model.input1(), {false, true, false, false}); + model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor<int32>(model.input3(), {5, 6, 7, 8}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput<int32>(), ElementsAreArray({5, 2, 7, 8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(SelectOpTest, RankOneSelectInt32) { + SelectOpModel model({2}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32); + + model.PopulateTensor<bool>(model.input1(), {false, true}); + model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 4}); + 
model.PopulateTensor<int32>(model.input3(), {5, 6, 7, 8}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput<int32>(), ElementsAreArray({5, 6, 3, 4})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1})); +} + +TEST(SelectOpTest, RankZeroSelectInt32) { + SelectOpModel model({1}, {1, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32); + + model.PopulateTensor<bool>(model.input1(), {false}); + model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor<int32>(model.input3(), {5, 6, 7, 8}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput<int32>(), ElementsAreArray({5, 6, 7, 8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index 0bb28b50b2..5a6c85e97e 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -22,23 +22,6 @@ namespace tflite { using ::testing::FloatNear; using ::testing::Matcher; -namespace { -template <typename T> -std::pair<float, int32_t> QuantizationParams(float f_min, float f_max) { - // These are required by many quantized operations. - CHECK_LE(f_min, 0); - CHECK_GE(f_max, 0); - T q_min = std::numeric_limits<T>::min(); - T q_max = std::numeric_limits<T>::max(); - float range = q_max - q_min; - float scale = (f_max - f_min) / range; - int32_t zero_point = std::min( - q_max, - std::max(q_min, static_cast<T>(std::round(q_min - f_min / scale)))); - return {scale, zero_point}; -} -} // namespace - std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values, float max_abs_error) { std::vector<Matcher<float>> matchers; @@ -49,69 +32,8 @@ std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values, return matchers; } -int SingleOpModel::AddTensor(TensorData t, std::initializer_list<int> data) { - int id = tensors_.size(); - - // This is slightly different depending on whether we are adding a - // quantized or a regular tensor. - bool is_quantized = (t.min != 0 || t.max != 0 || t.scale != 0); - - flatbuffers::Offset<QuantizationParameters> q_params = 0; - - if (is_quantized) { - if (t.min != 0 || t.max != 0) { - if (t.type == TensorType_UINT8) { - std::tie(t.scale, t.zero_point) = - QuantizationParams<uint8_t>(t.min, t.max); - } else if (t.type == TensorType_INT32) { - std::tie(t.scale, t.zero_point) = - QuantizationParams<int32_t>(t.min, t.max); - } else { - LOG(FATAL) << "No support for the requested quantized type"; - } - t.min = 0; - t.max = 0; - } - - q_params = CreateQuantizationParameters( - builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>({t.scale}), - builder_.CreateVector<int64_t>({t.zero_point})); - } - - int buffer_id = 0; - if (data.size()) { - // Initialize buffers list with empty buffer to allow for non-const tensors. - if (buffers_.empty()) { - buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({}))); - } - - // Add data as a Buffer to buffers list. 
- buffer_id = buffers_.size(); - auto data_buffer = - builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.begin()), - sizeof(int) * data.size()); - buffers_.push_back(CreateBuffer(builder_, data_buffer)); - } - - tensors_.push_back(CreateTensor(builder_, builder_.CreateVector<int>(t.shape), - t.type, /*buffer=*/buffer_id, - /*name=*/0, q_params)); - - tensor_data_[id] = t; - - return id; -} - int SingleOpModel::AddInput(const TensorData& t) { - int id = AddTensor(t, {}); - inputs_.push_back(id); - return id; -} - -int SingleOpModel::AddConstInput(TensorType type, - std::initializer_list<int> data, - std::initializer_list<int> shape) { - int id = AddTensor(TensorData{type, shape}, data); + int id = AddTensor<float>(t, {}); inputs_.push_back(id); return id; } @@ -123,7 +45,7 @@ int SingleOpModel::AddNullInput() { } int SingleOpModel::AddOutput(const TensorData& t) { - int id = AddTensor(t, {}); + int id = AddTensor<float>(t, {}); outputs_.push_back(id); return id; } diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index 6fb6fe27eb..6a9fdf1112 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -116,9 +116,14 @@ class SingleOpModel { int AddInput(TensorType type) { return AddInput(TensorData{type}); } int AddInput(const TensorData& t); - // Add a Tensor containing const data and return the tensor id. - int AddConstInput(TensorType type, std::initializer_list<int> data, - std::initializer_list<int> shape); + // Templated version of AddConstInput(). + template <typename T> + int AddConstInput(TensorType type, std::initializer_list<T> data, + std::initializer_list<int> shape) { + int id = AddTensor(TensorData{type, shape}, data); + inputs_.push_back(id); + return id; + } // Add a null input tensor (optional input) and return kOptionalTensor. int AddNullInput(); @@ -224,7 +229,79 @@ class SingleOpModel { std::unique_ptr<OpResolver> resolver_; private: - int AddTensor(TensorData t, std::initializer_list<int> data); + // TODO(gavinbelson): sync this method with + // //tensorflow/contrib/lite/kernels/internal/quantization_util.h?l=31 + template <typename T> + std::pair<float, int32_t> QuantizationParams(float f_min, float f_max) { + // These are required by many quantized operations. + CHECK_LE(f_min, 0); + CHECK_GE(f_max, 0); + T q_min = std::numeric_limits<T>::min(); + T q_max = std::numeric_limits<T>::max(); + float range = q_max - q_min; + float scale = (f_max - f_min) / range; + int32_t zero_point = std::min( + q_max, + std::max(q_min, static_cast<T>(std::round(q_min - f_min / scale)))); + return {scale, zero_point}; + } + + template <typename T> + int AddTensor(TensorData t, std::initializer_list<T> data) { + int id = tensors_.size(); + + // This is slightly different depending on whether we are adding a + // quantized or a regular tensor. 
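For concreteness, the values this produces for the [-1, 1] ranges used by the quantized PadV2 tests above: range = 255, scale = 2/255 (about 0.00784), and zero_point = round(0 - (-1)/scale) = 128. A standalone check of that arithmetic (a sketch, not part of the patch):

#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const float f_min = -1.0f, f_max = 1.0f;       // real range of the tensor
  const float scale = (f_max - f_min) / 255.0f;  // q_max - q_min for uint8
  const int32_t zero_point =
      static_cast<int32_t>(std::round(0 - f_min / scale));  // q_min == 0
  assert(zero_point == 128);  // round(127.5), half away from zero
  return 0;
}

The is_quantized branch that follows applies this same mapping whenever a tensor declares a min/max or an explicit scale: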
+ bool is_quantized = (t.min != 0 || t.max != 0 || t.scale != 0); + + flatbuffers::Offset<QuantizationParameters> q_params = 0; + + if (is_quantized) { + if (t.min != 0 || t.max != 0) { + if (t.type == TensorType_UINT8) { + std::tie(t.scale, t.zero_point) = + QuantizationParams<uint8_t>(t.min, t.max); + } else if (t.type == TensorType_INT32) { + std::tie(t.scale, t.zero_point) = + QuantizationParams<int32_t>(t.min, t.max); + } else { + LOG(FATAL) << "No support for the requested quantized type"; + } + t.min = 0; + t.max = 0; + } + + q_params = CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, + builder_.CreateVector<float>({t.scale}), + builder_.CreateVector<int64_t>({t.zero_point})); + } + + int buffer_id = 0; + if (data.size()) { + // Initialize buffers list with empty buffer to allow for non-const + // tensors. + if (buffers_.empty()) { + buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({}))); + } + + // Add data as a Buffer to buffers list. + buffer_id = buffers_.size(); + auto data_buffer = + builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.begin()), + sizeof(T) * data.size()); + buffers_.push_back(CreateBuffer(builder_, data_buffer)); + } + + tensors_.push_back(CreateTensor(builder_, + builder_.CreateVector<int>(t.shape), t.type, + /*buffer=*/buffer_id, + /*name=*/0, q_params)); + + tensor_data_[id] = t; + + return id; + } std::map<int, TensorData> tensor_data_; std::vector<int32_t> inputs_; diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 590f042e21..e89036ce73 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -569,6 +569,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_PAD: { break; } + case BuiltinOperator_PADV2: { + break; + } case BuiltinOperator_RESHAPE: { auto* params = MallocPOD<TfLiteReshapeParams>(); if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { @@ -669,7 +672,11 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } - case BuiltinOperator_LESS: { + case BuiltinOperator_GREATER: + case BuiltinOperator_GREATER_EQUAL: + case BuiltinOperator_LESS: + case BuiltinOperator_LESS_EQUAL: + case BuiltinOperator_SELECT: { break; } case BuiltinOperator_DELEGATE: { diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 6eac18c4f5..eb451397bd 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -61,6 +61,10 @@ NNAPIAllocation::~NNAPIAllocation() { } NNAPIDelegate::~NNAPIDelegate() { + if (nn_compiled_model_) { + ANeuralNetworksCompilation_free(nn_compiled_model_); + nn_compiled_model_ = nullptr; + } if (nn_model_) { ANeuralNetworksModel_free(nn_model_); nn_model_ = nullptr; @@ -347,6 +351,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_L2_NORMALIZATION: case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: case tflite::BuiltinOperator_PAD: + case tflite::BuiltinOperator_PADV2: case tflite::BuiltinOperator_RESIZE_BILINEAR: case tflite::BuiltinOperator_CALL: case tflite::BuiltinOperator_SKIP_GRAM: @@ -371,8 +376,12 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_MAXIMUM: case tflite::BuiltinOperator_MINIMUM: case tflite::BuiltinOperator_ARG_MAX: + case tflite::BuiltinOperator_GREATER: + case tflite::BuiltinOperator_GREATER_EQUAL: case 
tflite::BuiltinOperator_LESS: + case tflite::BuiltinOperator_LESS_EQUAL: case tflite::BuiltinOperator_NEG: + case tflite::BuiltinOperator_SELECT: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 5d89f7be62..2f5c39e7d7 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -137,6 +137,11 @@ enum BuiltinOperator : byte { MINIMUM = 57, LESS = 58, NEG = 59, + PADV2 = 60, + GREATER = 61, + GREATER_EQUAL = 62, + LESS_EQUAL = 63, + SELECT = 64, } // Options for the builtin operators. @@ -183,6 +188,11 @@ union BuiltinOptions { ArgMaxOptions, LessOptions, NegOptions, + PadV2Options, + GreaterOptions, + GreaterEqualOptions, + LessEqualOptions, + SelectOptions, } enum Padding : byte { SAME, VALID } @@ -316,6 +326,9 @@ table CallOptions { table PadOptions { } +table PadV2Options { +} + table ReshapeOptions { new_shape:[int]; } @@ -405,12 +418,24 @@ table ArgMaxOptions { output_type : TensorType; } +table GreaterOptions { +} + +table GreaterEqualOptions { +} + table LessOptions { } +table LessEqualOptions { +} + table NegOptions { } +table SelectOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index c172f77aa9..a2f0c8cdd2 100755..100644 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -88,6 +88,9 @@ struct CallOptionsT; struct PadOptions; struct PadOptionsT; +struct PadV2Options; +struct PadV2OptionsT; + struct ReshapeOptions; struct ReshapeOptionsT; @@ -151,12 +154,24 @@ struct MaximumMinimumOptionsT; struct ArgMaxOptions; struct ArgMaxOptionsT; +struct GreaterOptions; +struct GreaterOptionsT; + +struct GreaterEqualOptions; +struct GreaterEqualOptionsT; + struct LessOptions; struct LessOptionsT; +struct LessEqualOptions; +struct LessEqualOptionsT; + struct NegOptions; struct NegOptionsT; +struct SelectOptions; +struct SelectOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -276,11 +291,16 @@ enum BuiltinOperator { BuiltinOperator_MINIMUM = 57, BuiltinOperator_LESS = 58, BuiltinOperator_NEG = 59, + BuiltinOperator_PADV2 = 60, + BuiltinOperator_GREATER = 61, + BuiltinOperator_GREATER_EQUAL = 62, + BuiltinOperator_LESS_EQUAL = 63, + BuiltinOperator_SELECT = 64, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_NEG + BuiltinOperator_MAX = BuiltinOperator_SELECT }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[59] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[64] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -340,7 +360,12 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[59] { BuiltinOperator_ARG_MAX, BuiltinOperator_MINIMUM, BuiltinOperator_LESS, - BuiltinOperator_NEG + BuiltinOperator_NEG, + BuiltinOperator_PADV2, + BuiltinOperator_GREATER, + BuiltinOperator_GREATER_EQUAL, + BuiltinOperator_LESS_EQUAL, + BuiltinOperator_SELECT }; return values; } @@ -407,6 +432,11 @@ inline const char **EnumNamesBuiltinOperator() { "MINIMUM", "LESS", "NEG", + "PADV2", + "GREATER", + "GREATER_EQUAL", + "LESS_EQUAL", + "SELECT", nullptr }; return names; @@ -461,11 +491,16 @@ enum BuiltinOptions { 
BuiltinOptions_ArgMaxOptions = 40, BuiltinOptions_LessOptions = 41, BuiltinOptions_NegOptions = 42, + BuiltinOptions_PadV2Options = 43, + BuiltinOptions_GreaterOptions = 44, + BuiltinOptions_GreaterEqualOptions = 45, + BuiltinOptions_LessEqualOptions = 46, + BuiltinOptions_SelectOptions = 47, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_NegOptions + BuiltinOptions_MAX = BuiltinOptions_SelectOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[43] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[48] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -509,7 +544,12 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[43] { BuiltinOptions_MaximumMinimumOptions, BuiltinOptions_ArgMaxOptions, BuiltinOptions_LessOptions, - BuiltinOptions_NegOptions + BuiltinOptions_NegOptions, + BuiltinOptions_PadV2Options, + BuiltinOptions_GreaterOptions, + BuiltinOptions_GreaterEqualOptions, + BuiltinOptions_LessEqualOptions, + BuiltinOptions_SelectOptions }; return values; } @@ -559,6 +599,11 @@ inline const char **EnumNamesBuiltinOptions() { "ArgMaxOptions", "LessOptions", "NegOptions", + "PadV2Options", + "GreaterOptions", + "GreaterEqualOptions", + "LessEqualOptions", + "SelectOptions", nullptr }; return names; @@ -741,6 +786,26 @@ template<> struct BuiltinOptionsTraits<NegOptions> { static const BuiltinOptions enum_value = BuiltinOptions_NegOptions; }; +template<> struct BuiltinOptionsTraits<PadV2Options> { + static const BuiltinOptions enum_value = BuiltinOptions_PadV2Options; +}; + +template<> struct BuiltinOptionsTraits<GreaterOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_GreaterOptions; +}; + +template<> struct BuiltinOptionsTraits<GreaterEqualOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_GreaterEqualOptions; +}; + +template<> struct BuiltinOptionsTraits<LessEqualOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_LessEqualOptions; +}; + +template<> struct BuiltinOptionsTraits<SelectOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_SelectOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1108,6 +1173,46 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_NegOptions ? reinterpret_cast<const NegOptionsT *>(value) : nullptr; } + PadV2OptionsT *AsPadV2Options() { + return type == BuiltinOptions_PadV2Options ? + reinterpret_cast<PadV2OptionsT *>(value) : nullptr; + } + const PadV2OptionsT *AsPadV2Options() const { + return type == BuiltinOptions_PadV2Options ? + reinterpret_cast<const PadV2OptionsT *>(value) : nullptr; + } + GreaterOptionsT *AsGreaterOptions() { + return type == BuiltinOptions_GreaterOptions ? + reinterpret_cast<GreaterOptionsT *>(value) : nullptr; + } + const GreaterOptionsT *AsGreaterOptions() const { + return type == BuiltinOptions_GreaterOptions ? + reinterpret_cast<const GreaterOptionsT *>(value) : nullptr; + } + GreaterEqualOptionsT *AsGreaterEqualOptions() { + return type == BuiltinOptions_GreaterEqualOptions ? + reinterpret_cast<GreaterEqualOptionsT *>(value) : nullptr; + } + const GreaterEqualOptionsT *AsGreaterEqualOptions() const { + return type == BuiltinOptions_GreaterEqualOptions ? + reinterpret_cast<const GreaterEqualOptionsT *>(value) : nullptr; + } + LessEqualOptionsT *AsLessEqualOptions() { + return type == BuiltinOptions_LessEqualOptions ? 
+ reinterpret_cast<LessEqualOptionsT *>(value) : nullptr; + } + const LessEqualOptionsT *AsLessEqualOptions() const { + return type == BuiltinOptions_LessEqualOptions ? + reinterpret_cast<const LessEqualOptionsT *>(value) : nullptr; + } + SelectOptionsT *AsSelectOptions() { + return type == BuiltinOptions_SelectOptions ? + reinterpret_cast<SelectOptionsT *>(value) : nullptr; + } + const SelectOptionsT *AsSelectOptions() const { + return type == BuiltinOptions_SelectOptions ? + reinterpret_cast<const SelectOptionsT *>(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -2873,6 +2978,46 @@ inline flatbuffers::Offset<PadOptions> CreatePadOptions( flatbuffers::Offset<PadOptions> CreatePadOptions(flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct PadV2OptionsT : public flatbuffers::NativeTable { + typedef PadV2Options TableType; + PadV2OptionsT() { + } +}; + +struct PadV2Options FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef PadV2OptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + PadV2OptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PadV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<PadV2Options> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PadV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct PadV2OptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit PadV2OptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + PadV2OptionsBuilder &operator=(const PadV2OptionsBuilder &); + flatbuffers::Offset<PadV2Options> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<PadV2Options>(end); + return o; + } +}; + +inline flatbuffers::Offset<PadV2Options> CreatePadV2Options( + flatbuffers::FlatBufferBuilder &_fbb) { + PadV2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset<PadV2Options> CreatePadV2Options(flatbuffers::FlatBufferBuilder &_fbb, const PadV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct ReshapeOptionsT : public flatbuffers::NativeTable { typedef ReshapeOptions TableType; std::vector<int32_t> new_shape; @@ -3995,6 +4140,86 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions( flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct GreaterOptionsT : public flatbuffers::NativeTable { + typedef GreaterOptions TableType; + GreaterOptionsT() { + } +}; + +struct GreaterOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef GreaterOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + GreaterOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GreaterOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<GreaterOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT* _o, const flatbuffers::rehasher_function_t 
*_rehasher = nullptr); +}; + +struct GreaterOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit GreaterOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + GreaterOptionsBuilder &operator=(const GreaterOptionsBuilder &); + flatbuffers::Offset<GreaterOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<GreaterOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<GreaterOptions> CreateGreaterOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + GreaterOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset<GreaterOptions> CreateGreaterOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct GreaterEqualOptionsT : public flatbuffers::NativeTable { + typedef GreaterEqualOptions TableType; + GreaterEqualOptionsT() { + } +}; + +struct GreaterEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef GreaterEqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + GreaterEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GreaterEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<GreaterEqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct GreaterEqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit GreaterEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + GreaterEqualOptionsBuilder &operator=(const GreaterEqualOptionsBuilder &); + flatbuffers::Offset<GreaterEqualOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<GreaterEqualOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<GreaterEqualOptions> CreateGreaterEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + GreaterEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset<GreaterEqualOptions> CreateGreaterEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct LessOptionsT : public flatbuffers::NativeTable { typedef LessOptions TableType; LessOptionsT() { @@ -4035,6 +4260,46 @@ inline flatbuffers::Offset<LessOptions> CreateLessOptions( flatbuffers::Offset<LessOptions> CreateLessOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct LessEqualOptionsT : public flatbuffers::NativeTable { + typedef LessEqualOptions TableType; + LessEqualOptionsT() { + } +}; + +struct LessEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LessEqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LessEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LessEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<LessEqualOptions> 
Pack(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LessEqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit LessEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LessEqualOptionsBuilder &operator=(const LessEqualOptionsBuilder &); + flatbuffers::Offset<LessEqualOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<LessEqualOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<LessEqualOptions> CreateLessEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + LessEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset<LessEqualOptions> CreateLessEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct NegOptionsT : public flatbuffers::NativeTable { typedef NegOptions TableType; NegOptionsT() { @@ -4075,6 +4340,46 @@ inline flatbuffers::Offset<NegOptions> CreateNegOptions( flatbuffers::Offset<NegOptions> CreateNegOptions(flatbuffers::FlatBufferBuilder &_fbb, const NegOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct SelectOptionsT : public flatbuffers::NativeTable { + typedef SelectOptions TableType; + SelectOptionsT() { + } +}; + +struct SelectOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SelectOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + SelectOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SelectOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<SelectOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SelectOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit SelectOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SelectOptionsBuilder &operator=(const SelectOptionsBuilder &); + flatbuffers::Offset<SelectOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<SelectOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<SelectOptions> CreateSelectOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + SelectOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset<SelectOptions> CreateSelectOptions(flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -4318,6 +4623,21 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const NegOptions *builtin_options_as_NegOptions() const { return builtin_options_type() == BuiltinOptions_NegOptions ? static_cast<const NegOptions *>(builtin_options()) : nullptr; } + const PadV2Options *builtin_options_as_PadV2Options() const { + return builtin_options_type() == BuiltinOptions_PadV2Options ? 
static_cast<const PadV2Options *>(builtin_options()) : nullptr; + } + const GreaterOptions *builtin_options_as_GreaterOptions() const { + return builtin_options_type() == BuiltinOptions_GreaterOptions ? static_cast<const GreaterOptions *>(builtin_options()) : nullptr; + } + const GreaterEqualOptions *builtin_options_as_GreaterEqualOptions() const { + return builtin_options_type() == BuiltinOptions_GreaterEqualOptions ? static_cast<const GreaterEqualOptions *>(builtin_options()) : nullptr; + } + const LessEqualOptions *builtin_options_as_LessEqualOptions() const { + return builtin_options_type() == BuiltinOptions_LessEqualOptions ? static_cast<const LessEqualOptions *>(builtin_options()) : nullptr; + } + const SelectOptions *builtin_options_as_SelectOptions() const { + return builtin_options_type() == BuiltinOptions_SelectOptions ? static_cast<const SelectOptions *>(builtin_options()) : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -4512,6 +4832,26 @@ template<> inline const NegOptions *Operator::builtin_options_as<NegOptions>() c return builtin_options_as_NegOptions(); } +template<> inline const PadV2Options *Operator::builtin_options_as<PadV2Options>() const { + return builtin_options_as_PadV2Options(); +} + +template<> inline const GreaterOptions *Operator::builtin_options_as<GreaterOptions>() const { + return builtin_options_as_GreaterOptions(); +} + +template<> inline const GreaterEqualOptions *Operator::builtin_options_as<GreaterEqualOptions>() const { + return builtin_options_as_GreaterEqualOptions(); +} + +template<> inline const LessEqualOptions *Operator::builtin_options_as<LessEqualOptions>() const { + return builtin_options_as_LessEqualOptions(); +} + +template<> inline const SelectOptions *Operator::builtin_options_as<SelectOptions>() const { + return builtin_options_as_SelectOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -5572,6 +5912,29 @@ inline flatbuffers::Offset<PadOptions> CreatePadOptions(flatbuffers::FlatBufferB _fbb); } +inline PadV2OptionsT *PadV2Options::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new PadV2OptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void PadV2Options::UnPackTo(PadV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset<PadV2Options> PadV2Options::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PadV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreatePadV2Options(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<PadV2Options> CreatePadV2Options(flatbuffers::FlatBufferBuilder &_fbb, const PadV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PadV2OptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreatePadV2Options( + _fbb); +} + inline ReshapeOptionsT *ReshapeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ReshapeOptionsT(); UnPackTo(_o, _resolver); @@ -6115,6 +6478,52 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatB _output_type); } +inline GreaterOptionsT *GreaterOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + 
auto _o = new GreaterOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void GreaterOptions::UnPackTo(GreaterOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset<GreaterOptions> GreaterOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateGreaterOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<GreaterOptions> CreateGreaterOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GreaterOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateGreaterOptions( + _fbb); +} + +inline GreaterEqualOptionsT *GreaterEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new GreaterEqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void GreaterEqualOptions::UnPackTo(GreaterEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset<GreaterEqualOptions> GreaterEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateGreaterEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<GreaterEqualOptions> CreateGreaterEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GreaterEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateGreaterEqualOptions( + _fbb); +} + inline LessOptionsT *LessOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new LessOptionsT(); UnPackTo(_o, _resolver); @@ -6138,6 +6547,29 @@ inline flatbuffers::Offset<LessOptions> CreateLessOptions(flatbuffers::FlatBuffe _fbb); } +inline LessEqualOptionsT *LessEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LessEqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LessEqualOptions::UnPackTo(LessEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset<LessEqualOptions> LessEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLessEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<LessEqualOptions> CreateLessEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LessEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLessEqualOptions( + _fbb); +} + inline NegOptionsT *NegOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new NegOptionsT(); UnPackTo(_o, _resolver); @@ -6161,6 +6593,29 @@ inline 
flatbuffers::Offset<NegOptions> CreateNegOptions(flatbuffers::FlatBufferB _fbb); } +inline SelectOptionsT *SelectOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SelectOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SelectOptions::UnPackTo(SelectOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset<SelectOptions> SelectOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSelectOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<SelectOptions> CreateSelectOptions(flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SelectOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateSelectOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -6512,6 +6967,26 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast<const NegOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_PadV2Options: { + auto ptr = reinterpret_cast<const PadV2Options *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GreaterOptions: { + auto ptr = reinterpret_cast<const GreaterOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GreaterEqualOptions: { + auto ptr = reinterpret_cast<const GreaterEqualOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LessEqualOptions: { + auto ptr = reinterpret_cast<const LessEqualOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SelectOptions: { + auto ptr = reinterpret_cast<const SelectOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -6698,6 +7173,26 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast<const NegOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_PadV2Options: { + auto ptr = reinterpret_cast<const PadV2Options *>(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_GreaterOptions: { + auto ptr = reinterpret_cast<const GreaterOptions *>(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_GreaterEqualOptions: { + auto ptr = reinterpret_cast<const GreaterEqualOptions *>(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LessEqualOptions: { + auto ptr = reinterpret_cast<const LessEqualOptions *>(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SelectOptions: { + auto ptr = reinterpret_cast<const SelectOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -6872,6 +7367,26 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast<const NegOptionsT *>(value); return CreateNegOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_PadV2Options: { + auto ptr = reinterpret_cast<const PadV2OptionsT *>(value); + return CreatePadV2Options(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_GreaterOptions: { + auto ptr = reinterpret_cast<const 
GreaterOptionsT *>(value); + return CreateGreaterOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_GreaterEqualOptions: { + auto ptr = reinterpret_cast<const GreaterEqualOptionsT *>(value); + return CreateGreaterEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LessEqualOptions: { + auto ptr = reinterpret_cast<const LessEqualOptionsT *>(value); + return CreateLessEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SelectOptions: { + auto ptr = reinterpret_cast<const SelectOptionsT *>(value); + return CreateSelectOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -7046,6 +7561,26 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new NegOptionsT(*reinterpret_cast<NegOptionsT *>(u.value)); break; } + case BuiltinOptions_PadV2Options: { + value = new PadV2OptionsT(*reinterpret_cast<PadV2OptionsT *>(u.value)); + break; + } + case BuiltinOptions_GreaterOptions: { + value = new GreaterOptionsT(*reinterpret_cast<GreaterOptionsT *>(u.value)); + break; + } + case BuiltinOptions_GreaterEqualOptions: { + value = new GreaterEqualOptionsT(*reinterpret_cast<GreaterEqualOptionsT *>(u.value)); + break; + } + case BuiltinOptions_LessEqualOptions: { + value = new LessEqualOptionsT(*reinterpret_cast<LessEqualOptionsT *>(u.value)); + break; + } + case BuiltinOptions_SelectOptions: { + value = new SelectOptionsT(*reinterpret_cast<SelectOptionsT *>(u.value)); + break; + } default: break; } @@ -7263,6 +7798,31 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_PadV2Options: { + auto ptr = reinterpret_cast<PadV2OptionsT *>(value); + delete ptr; + break; + } + case BuiltinOptions_GreaterOptions: { + auto ptr = reinterpret_cast<GreaterOptionsT *>(value); + delete ptr; + break; + } + case BuiltinOptions_GreaterEqualOptions: { + auto ptr = reinterpret_cast<GreaterEqualOptionsT *>(value); + delete ptr; + break; + } + case BuiltinOptions_LessEqualOptions: { + auto ptr = reinterpret_cast<LessEqualOptionsT *>(value); + delete ptr; + break; + } + case BuiltinOptions_SelectOptions: { + auto ptr = reinterpret_cast<SelectOptionsT *>(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 211de63d58..f89c0d28d3 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -33,9 +33,12 @@ gen_zipped_test_files( "fused_batch_norm.zip", "gather.zip", "global_batch_norm.zip", + "greater.zip", + "greater_equal.zip", "l2_pool.zip", "l2norm.zip", "less.zip", + "less_equal.zip", "local_response_norm.zip", "log_softmax.zip", "max_pool.zip", @@ -45,6 +48,7 @@ gen_zipped_test_files( "mul.zip", "neg.zip", "pad.zip", + "padv2.zip", "relu.zip", "relu1.zip", "relu6.zip", @@ -60,6 +64,7 @@ gen_zipped_test_files( "sub.zip", "topk.zip", "transpose.zip", + "where.zip", ], ) diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 926bb3f121..f7cc7da900 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -1391,6 +1391,60 @@ def make_pad_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_padv2_tests(zip_path): + """Make a set of tests to do padv2.""" + + # TODO(nupurgarg): Add test for tf.uint8. 
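For reference, the PadV2 kernel exercised here follows tf.pad with an explicit constant value; a minimal NumPy sketch of the expected semantics (shapes and values are illustrative, not drawn from the generated zips):
    import numpy as np

    # Each paddings row is (before, after) for one dimension.
    out = np.pad(np.array([[1, 2]]), pad_width=[[0, 1], [2, 0]],
                 mode="constant", constant_values=7)
    # out == [[7, 7, 1, 2],
    #         [7, 7, 7, 7]]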
+ test_parameters = [ + { + "dtype": [tf.int32, tf.int64, tf.float32], + "input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]], + "paddings": [[[0, 0], [0, 1], [2, 3], [0, 0]], [[0, 1], [0, 0], + [0, 0], [2, 3]]], + "constant_paddings": [True, False], + "constant_values": [0, 2], + }, + # Non-4D use case. + { + "dtype": [tf.int32, tf.int64, tf.float32], + "input_shape": [[1, 2], [0, 1, 2]], + "paddings": [[[0, 1], [2, 3]]], + "constant_paddings": [True, False], + "constant_values": [0, 2], + }, + ] + + def build_graph(parameters): + """Build a pad graph given `parameters`.""" + input_tensor = tf.placeholder( + dtype=parameters["dtype"], + name="input", + shape=parameters["input_shape"]) + + # Get paddings as either a placeholder or constants. + if parameters["constant_paddings"]: + paddings = parameters["paddings"] + input_tensors = [input_tensor] + else: + shape = [len(parameters["paddings"]), 2] + paddings = tf.placeholder(dtype=tf.int32, name="padding", shape=shape) + input_tensors = [input_tensor, paddings] + + out = tf.pad(input_tensor, paddings=paddings, + constant_values=parameters["constant_values"]) + return input_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + values = [ + create_tensor_data(parameters["dtype"], parameters["input_shape"]) + ] + if not parameters["constant_paddings"]: + values.append(np.array(parameters["paddings"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_reshape_tests(zip_path): """Make a set of tests to do reshape.""" @@ -2001,6 +2055,74 @@ def make_arg_max_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_greater_tests(zip_path): + """Make a set of tests to do greater.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the greater op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.greater(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_greater_equal_tests(zip_path): + """Make a set of tests to do greater_equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the greater_equal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + 
shape=parameters["input_shape_pair"][1]) + out = tf.greater_equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_less_tests(zip_path): """Make a set of tests to do less.""" @@ -2035,6 +2157,40 @@ def make_less_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_less_equal_tests(zip_path): + """Make a set of tests to do less_equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the less_equal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.less_equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_floor_tests(zip_path): """Make a set of tests to do floor.""" @@ -2086,10 +2242,41 @@ def make_neg_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_where_tests(zip_path): + """Make a set of tests to do where.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape_set": [([1, 2, 3, 4], [1, 2, 3, 4]),], + }] + + def build_graph(parameters): + """Build the where op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_set"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input3", + shape=parameters["input_shape_set"][1]) + less = tf.less(input_value1, input_value2) + out = tf.where(less, input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_set"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_set"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + # Toco binary path provided by the generate rule. 
bin_path = None - def main(unused_args): global bin_path def mkdir_if_not_exist(x): diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 0673a3bb46..49762bdfe7 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -54,9 +54,11 @@ std::map<string, string> kBrokenTests = { {R"(^\/div.*int32)", "68808744"}, {R"(^\/sub.*int32)", "68808744"}, - // Pad only supports 4D tensors. + // Pad and PadV2 only support 4D tensors. {R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])", "70527055"}, + {R"(^\/padv2.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])", + "70527055"}, // L2Norm only supports tensors with 4D or fewer. {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, @@ -256,9 +258,12 @@ INSTANTIATE_TESTS(fully_connected) INSTANTIATE_TESTS(fused_batch_norm) INSTANTIATE_TESTS(gather) INSTANTIATE_TESTS(global_batch_norm) +INSTANTIATE_TESTS(greater) +INSTANTIATE_TESTS(greater_equal) INSTANTIATE_TESTS(l2_pool) INSTANTIATE_TESTS(l2norm) INSTANTIATE_TESTS(less) +INSTANTIATE_TESTS(less_equal) INSTANTIATE_TESTS(local_response_norm) INSTANTIATE_TESTS(log_softmax) INSTANTIATE_TESTS(max_pool) @@ -268,6 +273,7 @@ INSTANTIATE_TESTS(minimum) INSTANTIATE_TESTS(mul) INSTANTIATE_TESTS(neg) INSTANTIATE_TESTS(pad) +INSTANTIATE_TESTS(padv2) // INSTANTIATE_TESTS(prelu) INSTANTIATE_TESTS(relu) INSTANTIATE_TESTS(relu1) @@ -283,6 +289,7 @@ INSTANTIATE_TESTS(squeeze) INSTANTIATE_TESTS(strided_slice) INSTANTIATE_TESTS(sub) INSTANTIATE_TESTS(transpose) +INSTANTIATE_TESTS(where) } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index ce0a74724a..01ce0d9db2 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -280,6 +280,7 @@ cc_library( "graph_transformations/resolve_mean_attributes.cc", "graph_transformations/resolve_multiply_by_zero.cc", "graph_transformations/resolve_pad_attributes.cc", + "graph_transformations/resolve_padv2_attributes.cc", "graph_transformations/resolve_reorder_axes.cc", "graph_transformations/resolve_reshape_attributes.cc", "graph_transformations/resolve_slice_attributes.cc", diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 99ccfaea64..f5157149af 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1492,6 +1492,37 @@ void ConvertPadOperator(const Model& model, const PadOperator& src_op, shape->add_dim()->set_size(2); } +void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("PadV2"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 3); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + *new_op->add_input() = src_op.inputs[2]; + + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + + // Create the params tensor.
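The Const node built below serializes the paddings row-major as a [rank, 2] tensor; with hypothetical attributes left_padding = {1, 2} and right_padding = {3, 4}, the loop emits int_vals in the order
    // 1, 3, 2, 4   i.e. the tensor [[1, 3], [2, 4]],
    // row i holding (left_padding[i], right_padding[i]) for dimension i.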
+ auto* params_op = tensorflow_graph->add_node(); + params_op->set_op("Const"); + params_op->set_name(src_op.inputs[1]); + (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + + CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size()); + for (int i = 0; i < src_op.left_padding.size(); ++i) { + tensor->add_int_val(src_op.left_padding[i]); + tensor->add_int_val(src_op.right_padding[i]); + } + auto* shape = tensor->mutable_tensor_shape(); + shape->add_dim()->set_size(src_op.left_padding.size()); + shape->add_dim()->set_size(2); +} + void CreateSliceInput(const string& input_name, const std::vector<int>& values, GraphDef* tensorflow_graph) { auto* params_op = tensorflow_graph->add_node(); @@ -1643,6 +1674,19 @@ void ConvertTensorFlowMaximumOperator(const Model& model, (*sub_op->mutable_attr())["T"].set_type(data_type); } +void ConvertSelectOperator(const Model& model, const SelectOperator& src_op, + GraphDef* tensorflow_graph) { + auto* sub_op = tensorflow_graph->add_node(); + sub_op->set_op("Select"); + sub_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 3); + *sub_op->add_input() = src_op.inputs[0]; + *sub_op->add_input() = src_op.inputs[1]; + *sub_op->add_input() = src_op.inputs[2]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[1]); + (*sub_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, GraphDef* tensorflow_graph) { auto* topk_op = tensorflow_graph->add_node(); @@ -1671,6 +1715,19 @@ void ConvertRandomUniformOperator(const Model& model, (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2); } +void ConvertComparisonOperator(const Model& model, const Operator& src_op, + const char* op_name, + GraphDef* tensorflow_graph) { + auto* comparison_op = tensorflow_graph->add_node(); + comparison_op->set_op(op_name); + comparison_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *comparison_op->add_input() = src_op.inputs[0]; + *comparison_op->add_input() = src_op.inputs[1]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*comparison_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1795,6 +1852,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kPad) { ConvertPadOperator(model, static_cast<const PadOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kPadV2) { + ConvertPadV2Operator(model, static_cast<const PadV2Operator&>(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kStridedSlice) { ConvertStridedSliceOperator( model, static_cast<const StridedSliceOperator&>(src_op), @@ -1859,6 +1919,17 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertRandomUniformOperator( model, static_cast<const RandomUniformOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowGreater) { + ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) { + ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowLess) { + 
ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowLessEqual) { + ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph); + } else if (src_op.type == OperatorType::kSelect) { + ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op), + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 72ffd51db4..4e3ea72182 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -174,6 +174,7 @@ DECLARE_GRAPH_TRANSFORMATION(UnrollBatchMatMul) DECLARE_GRAPH_TRANSFORMATION(ResolveSpaceToBatchNDAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveBatchToSpaceNDAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolvePadV2Attributes) DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index c1cf79f626..6342cf3e8a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -152,6 +152,17 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // Yield on ExpandDim until it is converted to Reshape return false; } + case OperatorType::kSelect: { + // Select produces outputs with the same type as its 2nd input + CHECK_EQ(op->inputs.size(), 3); + const ArrayDataType data_type_x = + model->GetArray(op->inputs[1]).data_type; + const ArrayDataType data_type_y = + model->GetArray(op->inputs[2]).data_type; + CHECK(data_type_x == data_type_y); + SetDataTypeForAllOutputs(model, op, data_type_x); + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 4923f83d91..52b739c5e2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -499,8 +499,8 @@ void ProcessTensorFlowReshapeOperator(Model* model, << op->outputs[0] << "\". Are your input shapes correct?"; } -void ProcessSimpleOperator(Model* model, Operator* op) { - const auto& input_array = model->GetArray(op->inputs[0]); +void ProcessSimpleOperator(Model* model, Operator* op, int input_index) { + const auto& input_array = model->GetArray(op->inputs[input_index]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { return; } @@ -529,6 +529,21 @@ void ProcessSimpleBinaryOperator(Model* model, Operator* op) { &output_array); } +void ProcessSelectOperator(Model* model, SelectOperator* op) { + // Yield until all input dims have been resolved.
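For context, the output dtype of Select was already pinned to the second input's dtype by the kSelect case in propagate_array_data_types.cc earlier in this diff, so shape propagation here only has to wait for shapes. A sketch with hypothetical shapes:
    // cond: [2, 2] (bool), x: [2, 2] (int32), y: [2, 2] (int32)
    // => output: shape [2, 2], dtype int32 (copied from x below)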
+ for (const auto& input : op->inputs) { + const auto& input_array = model->GetArray(input); + if (!input_array.has_shape()) { + return; + } + } + + // Select's output shape matches that of its second and third inputs. + const auto& input1_array = model->GetArray(op->inputs[1]); + auto& output_array = model->GetArray(op->outputs[0]); + output_array.copy_shape(input1_array.shape()); +} + void ProcessAddNOperator(Model* model, Operator* op) { // Yield until all input dims have been resolved. // @@ -670,8 +685,7 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { const auto& first_input_array = model->GetArray(op->inputs[0]); output_array.copy_shape(first_input_array.shape()); // Negative axis means the count starts at the back of the dims(). - int axis = op->axis; - if (axis < 0) axis += first_input_array.shape().dims().size(); + if (op->axis < 0) op->axis += first_input_array.shape().dims().size(); // Determine the concat size, and enforce that all inputs have // the same dimensions count. int concat_size = 0; @@ -684,14 +698,14 @@ CHECK_EQ(input_array.shape().dimensions_count(), output_array.shape().dimensions_count()); const std::vector<int>& input_dims = input_array.shape().dims(); - CHECK_LT(axis, input_dims.size()); - concat_size += input_dims[axis]; + CHECK_LT(op->axis, input_dims.size()); + concat_size += input_dims[op->axis]; } // Write out the concat_size on the output array shape. auto& output_shape = *output_array.mutable_shape(); auto& output_dims = *output_shape.mutable_dims(); - CHECK_LT(axis, output_shape.dimensions_count()); - output_dims[axis] = concat_size; + CHECK_LT(op->axis, output_shape.dimensions_count()); + output_dims[op->axis] = concat_size; } void ProcessRangeOperator(Model* model, RangeOperator* op) { @@ -1147,6 +1161,32 @@ void ProcessPadOperator(Model* model, PadOperator* op) { output_array.copy_shape(output_shape); } +void ProcessPadV2Operator(Model* model, PadV2Operator* op) { + CHECK_EQ(op->inputs.size(), 3); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = model->GetArray(op->inputs[0]); + + // Yield until input dims have been resolved.
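The loop that follows grows each dimension by its (before, after) padding, i.e. out_dims[i] = in_dims[i] + left_padding[i] + right_padding[i]. A quick check with hypothetical values:
    // in_dims = {1, 2}, left_padding = {0, 2}, right_padding = {1, 3}
    // => out_dims = {1 + 0 + 1, 2 + 2 + 3} = {2, 7}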
+ if (!input_array.has_shape()) return; + + if (op->left_padding.empty()) return; + CHECK_EQ(op->left_padding.size(), op->right_padding.size()); + + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) return; + + Shape output_shape = input_array.shape(); + std::vector<int>& dims = *output_shape.mutable_dims(); + CHECK_EQ(op->left_padding.size(), dims.size()); + + for (int i = 0; i < op->left_padding.size(); ++i) { + dims[i] += op->left_padding[i] + op->right_padding[i]; + } + + output_array.copy_shape(output_shape); +} + void ProcessRankOperator(Model* model, RankOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); @@ -1474,7 +1514,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kCast: case OperatorType::kFloor: case OperatorType::kExp: - ProcessSimpleOperator(model, op); + ProcessSimpleOperator(model, op, 0); break; case OperatorType::kGather: ProcessGatherOperator(model, static_cast<GatherOperator*>(op)); @@ -1545,7 +1585,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kMean: ProcessTensorFlowReductionOperator(model, op); break; - + case OperatorType::kSelect: + ProcessSelectOperator(model, static_cast<SelectOperator*>(op)); + break; case OperatorType::kSlice: ProcessSliceOperator(model, static_cast<SliceOperator*>(op)); break; @@ -1629,6 +1671,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kPad: ProcessPadOperator(model, static_cast<PadOperator*>(op)); break; + case OperatorType::kPadV2: + ProcessPadV2Operator(model, static_cast<PadV2Operator*>(op)); + break; case OperatorType::kStridedSlice: ProcessStridedSliceOperator(model, static_cast<StridedSliceOperator*>(op)); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 347302c7a5..a1ca7371c8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -48,13 +48,18 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kLogSoftmax || type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub || type == OperatorType::kSqueeze || type == OperatorType::kPad || + type == OperatorType::kPadV2 || type == OperatorType::kTensorFlowReshape || type == OperatorType::kTanh || type == OperatorType::kMul || type == OperatorType::kSpaceToDepth || type == OperatorType::kStridedSlice || type == OperatorType::kDepthToSpace || type == OperatorType::kLstmCell || type == OperatorType::kGather || - type == OperatorType::kTranspose || type == OperatorType::kMean; + type == OperatorType::kTranspose || type == OperatorType::kMean || + type == OperatorType::kTensorFlowGreater || + type == OperatorType::kTensorFlowGreaterEqual || + type == OperatorType::kTensorFlowLess || + type == OperatorType::kTensorFlowLessEqual; } const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { @@ -256,8 +261,7 @@ bool ChooseHardcodedQuantizationForOperatorOutput( IsExactlyRepresentable(0., *quantized_data_type, *quantization_params)); return true; } - if ((op.type == OperatorType::kLogistic) || - (op.type == OperatorType::kSoftmax)) { + if (op.type == OperatorType::kLogistic || op.type == OperatorType::kSoftmax) { // Logistic and Softmax have range: [0, 1]. 
// // For Logistic, 0.5 should be exactly representable, as implementations diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc index 2b3ee36ad1..8f2c1f8162 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc @@ -134,9 +134,9 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { } // Remove the old param arrays - model->EraseArray(bn_op->inputs[1]); - model->EraseArray(bn_op->inputs[2]); - model->EraseArray(bn_op->inputs[3]); + DeleteArrayIfUsedOnce(bn_op->inputs[1], model); + DeleteArrayIfUsedOnce(bn_op->inputs[2], model); + DeleteArrayIfUsedOnce(bn_op->inputs[3], model); // Remove the old operator DCHECK_EQ(bn_it->get(), bn_op); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc new file mode 100644 index 0000000000..ebb023e342 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolvePadV2Attributes::Run(Model* model, std::size_t op_index) { + const auto pad_it = model->operators.begin() + op_index; + auto* pad_op = pad_it->get(); + if (pad_op->type != OperatorType::kPadV2) return false; + + auto* op = static_cast<PadV2Operator*>(pad_op); + if (!op->left_padding.empty()) return false; + + CHECK_EQ(op->inputs.size(), 3); + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + + const auto& array = model->GetArray(op->inputs[1]); + if (!array.has_shape()) return false; + + const std::vector<int>& dims = array.shape().dims(); + CHECK_EQ(dims.size(), 2); + + std::vector<int> buffer = array.GetBuffer<ArrayDataType::kInt32>().data; + + for (int i = 0; i < dims[0]; ++i) { + op->left_padding.push_back(buffer[i * 2]); + op->right_padding.push_back(buffer[i * 2 + 1]); + } + + // TODO(dkalenichenko): Delete the extra input? 
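+  // For example, a paddings constant [[1, 2], [3, 4]] is stored row-major as
+  // {1, 2, 3, 4}, producing left_padding {1, 3} and right_padding {2, 4}.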
+ + return true; +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 8efe6ab7b9..1eef173afe 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -925,6 +925,19 @@ void ConvertPadOperator(const NodeDef& node, model->operators.emplace_back(op); } +void ConvertPadV2Operator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "PadV2"); + CheckInputsCount(node, tf_import_flags, 3); + auto* op = new PadV2Operator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->inputs.push_back(node.input(2)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + void ConvertShapeOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1331,6 +1344,19 @@ void ConvertUnsupportedOperator(const NodeDef& node, } } +void ConvertSelectOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CheckInputsCount(node, tf_import_flags, 3); + + auto* op = new SelectOperator; + for (const auto& input : node.input()) { + op->inputs.push_back(input); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + void ConvertStridedSliceOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -2169,6 +2195,8 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, ConvertMergeOperator(node, tf_import_flags, model); } else if (node.op() == "Pad") { ConvertPadOperator(node, tf_import_flags, model); + } else if (node.op() == "PadV2") { + ConvertPadV2Operator(node, tf_import_flags, model); } else if (node.op() == "StridedSlice") { ConvertStridedSliceOperator(node, tf_import_flags, model); } else if (node.op() == "Shape") { @@ -2239,6 +2267,8 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, ConvertDynamicStitchOperator(node, tf_import_flags, model); } else if (node.op() == "RandomUniform") { ConvertRandomUniform(node, tf_import_flags, model); + } else if (node.op() == "Select") { + ConvertSelectOperator(node, tf_import_flags, model); } else { ConvertUnsupportedOperator(node, tf_import_flags, model); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 482cc71d8b..47f8db5978 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -82,6 +82,7 @@ enum class OperatorType { kStack, kBatchToSpaceND, kPad, + kPadV2, kStridedSlice, kSlice, kSqueeze, @@ -132,6 +133,7 @@ enum class OperatorType { // instead of being given as plain constant arrays. So we need to insert // special nodes in the graph to shuffle axes. kReorderAxes, + kSelect, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -825,6 +827,29 @@ struct PadOperator : Operator { std::vector<int> right_padding; }; +// PaddingV2 operator. Pads a tensor with the given constant value. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: the padding array +// inputs[2]: required: the scalar constant_values +// +// This operation pads input according to the paddings and constant_values you +// specify. paddings is an integer tensor with shape [Dn, 2], where n is the +// rank of input. 
For each dimension D of input, paddings[D, 0] indicates how +// many padding values to add before the contents of input in that dimension, +// and paddings[D, 1] indicates how many padding values to add after the +// contents of input in that dimension. constant_values is a scalar tensor of +// the same type as input that indicates the value to use for padding input. +// +// TensorFlow equivalent: PadV2 +struct PadV2Operator : Operator { + PadV2Operator() : Operator(OperatorType::kPadV2) {} + + std::vector<int> left_padding; + std::vector<int> right_padding; +}; + // Strided slice operator. // // Inputs: @@ -1063,6 +1088,18 @@ struct NegOperator : Operator { NegOperator() : Operator(OperatorType::kNeg) {} }; +// Element-wise select operator choosing elements from inputs[1] or input[2] +// +// Inputs: +// inputs[0]: required: boolean mask per index +// inputs[1]: required: tensor of values if true +// inputs[2]: required: tensor of values if false +// +// TensorFlow equivalent: Select +struct SelectOperator : Operator { + SelectOperator() : Operator(OperatorType::kSelect) {} +}; + // Element-wise reciprocal-square-root (x^-0.5) operator. // // Inputs: diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index e18ae805c0..90e24aa104 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -465,6 +465,21 @@ class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions, TocoOperator* op) const override {} }; +class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options, + ::tflite::BuiltinOptions_PadV2Options> { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreatePadV2Options(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} +}; + class Reshape : public BuiltinOperator<TensorFlowReshapeOperator, ::tflite::ReshapeOptions, @@ -832,6 +847,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { OperatorType::kMaxPool)); ops.emplace_back(new Mul(::tflite::BuiltinOperator_MUL, OperatorType::kMul)); ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad)); + ops.emplace_back( + new PadV2(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2)); ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE, OperatorType::kTensorFlowReshape)); ops.emplace_back( @@ -898,9 +915,18 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { "MAXIMUM", OperatorType::kTensorFlowMaximum)); ops.emplace_back(new SimpleOperator<TensorFlowMinimumOperator>( "MINIMUM", OperatorType::kTensorFlowMinimum)); + ops.emplace_back(new SimpleOperator<TensorFlowGreaterOperator>( + "GREATER", OperatorType::kTensorFlowGreater)); + ops.emplace_back(new SimpleOperator<TensorFlowGreaterEqualOperator>( + "GREATER_EQUAL", OperatorType::kTensorFlowGreaterEqual)); ops.emplace_back(new SimpleOperator<TensorFlowLessOperator>( "LESS", OperatorType::kTensorFlowLess)); + ops.emplace_back(new SimpleOperator<TensorFlowLessEqualOperator>( + "LESS_EQUAL", OperatorType::kTensorFlowLessEqual)); ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg)); + ops.emplace_back( + new SimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect)); + return ops; } } // namespace diff --git 
a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 2b6c32b07c..a4fff9974a 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -116,6 +116,7 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator<TensorFlowLessOperator>("LESS", OperatorType::kTensorFlowLess); CheckSimpleOperator<NegOperator>("NEG", OperatorType::kNeg); + CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect); } TEST_F(OperatorTest, BuiltinAdd) { diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 6973b22c5a..58c99051bd 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -106,6 +106,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveSpaceToBatchNDAttributes); transformations->Add(new ResolveBatchToSpaceNDAttributes); transformations->Add(new ResolvePadAttributes); + transformations->Add(new ResolvePadV2Attributes); transformations->Add(new ResolveStridedSliceAttributes); transformations->Add(new ResolveSliceAttributes); transformations->Add(new ResolveMeanAttributes); diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 86ee1f3761..1f56fe5c83 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -143,6 +143,10 @@ int CountOpsWithInput(const Model& model, const string& array_name) { for (auto& input : op->inputs) { if (input == array_name) { count++; + // Breaking here is important: some graphs have ops that use the + // same array as more than one of their inputs, and in that case + // we want it counted only once. 
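+        // (For instance, an op that squares an array by multiplying it with
+        // itself lists that array as both of its inputs.)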
+ break; } } } @@ -352,6 +356,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum) HANDLE_OPERATORTYPENAME_CASE(Neg) HANDLE_OPERATORTYPENAME_CASE(Pad) + HANDLE_OPERATORTYPENAME_CASE(PadV2) HANDLE_OPERATORTYPENAME_CASE(StridedSlice) HANDLE_OPERATORTYPENAME_CASE(Stack) HANDLE_OPERATORTYPENAME_CASE(Range) @@ -386,6 +391,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Exp) HANDLE_OPERATORTYPENAME_CASE(DynamicPartition) HANDLE_OPERATORTYPENAME_CASE(DynamicStitch) + HANDLE_OPERATORTYPENAME_CASE(Select) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -2092,6 +2098,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) { return ArrayDataType::kInt32; case INT64: return ArrayDataType::kInt64; + case BOOL: + return ArrayDataType::kBool; default: return ArrayDataType::kNone; } diff --git a/tensorflow/contrib/lite/toco/types.proto b/tensorflow/contrib/lite/toco/types.proto index 03bd6150bc..421667a83c 100644 --- a/tensorflow/contrib/lite/toco/types.proto +++ b/tensorflow/contrib/lite/toco/types.proto @@ -37,4 +37,7 @@ enum IODataType { // Int16, quantized QUANTIZED_INT16 = 6; + + // Boolean + BOOL = 7; } diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index ea6032e588..4b7af18b33 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -396,14 +396,19 @@ class Pruning(object): self._block_pooling_function) with ops.name_scope(weights.op.name + '_pruning_ops'): - abs_weights = math_ops.abs( - array_ops.reshape(weights, [ - 1, - squeezed_weights.get_shape()[0], - squeezed_weights.get_shape()[1], 1 - ])) + abs_weights = math_ops.abs(squeezed_weights) + pool_window = [self._block_dim[0], self._block_dim[1]] - pooled_weights = nn_ops.pool( + pool_fn = pruning_utils.factorized_pool + + if not self._spec.use_tpu: + pool_fn = nn_ops.pool + abs_weights = array_ops.reshape( + abs_weights, + [1, abs_weights.get_shape()[0], + abs_weights.get_shape()[1], 1]) + + pooled_weights = pool_fn( abs_weights, window_shape=pool_window, pooling_type=self._block_pooling_function, @@ -411,19 +416,18 @@ class Pruning(object): padding='SAME', name=weights.op.name + '_pooled') + if pooled_weights.get_shape().ndims != 2: + pooled_weights = array_ops.squeeze(pooled_weights) + smoothed_threshold, new_mask = self._update_mask(pooled_weights, threshold) - - reshaped_mask = array_ops.reshape( - new_mask, - [pooled_weights.get_shape()[1], - pooled_weights.get_shape()[2]]) updated_mask = pruning_utils.kronecker_product( - reshaped_mask, array_ops.ones(self._block_dim)) + new_mask, array_ops.ones(self._block_dim)) sliced_mask = array_ops.slice( updated_mask, [0, 0], [squeezed_weights.get_shape()[0], squeezed_weights.get_shape()[1]]) + return smoothed_threshold, array_ops.reshape(sliced_mask, array_ops.shape(weights)) diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py index 56d3dcef20..ef6c6a3f5d 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import 
state_ops from tensorflow.python.ops import variable_scope @@ -221,6 +222,56 @@ def compute_cdf(values, value_range, **kwargs): return math_ops.div(cdf, math_ops.reduce_max(cdf)) +def factorized_pool(input_tensor, + window_shape, + pooling_type, + strides, + padding, + name=None): + """Performs m x n pooling through a combination of 1xm and 1xn pooling. + + Args: + input_tensor: Input tensor. Must be rank 2 + window_shape: Pooling window shape + pooling_type: Either 'MAX' or 'AVG' + strides: The stride of the pooling window + padding: 'SAME' or 'VALID'. + name: Name of the op + + Returns: + A rank 2 tensor containing the pooled output + + Raises: + ValueError: if the input tensor is not rank 2 + """ + if input_tensor.get_shape().ndims != 2: + raise ValueError('factorized_pool() accepts tensors of rank 2 only') + + [height, width] = input_tensor.get_shape() + with ops.name_scope(name, 'factorized_pool'): + input_tensor_aligned = array_ops.reshape( + input_tensor, [1, 1, height, width], + name=input_tensor.op.name + '_aligned') + + height_pooling = nn_ops.pool( + input_tensor_aligned, + window_shape=[1, window_shape[0]], + pooling_type=pooling_type, + strides=[1, strides[0]], + padding=padding) + swap_height_width = array_ops.transpose(height_pooling, perm=[0, 1, 3, 2]) + + width_pooling = nn_ops.pool( + swap_height_width, + window_shape=[1, window_shape[1]], + pooling_type=pooling_type, + strides=[1, strides[1]], + padding=padding) + + return array_ops.squeeze( + array_ops.transpose(width_pooling, perm=[0, 1, 3, 2])) + + def determine_partitioned_axis(partitioned_variable): partitioned_axis = 0 concatenated_variable_shape = partitioned_variable.get_shape() diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py index 10e1dd0a8e..ccde5b4e8a 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py @@ -22,8 +22,10 @@ import numpy as np from tensorflow.contrib.model_pruning.python import pruning_utils from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -31,6 +33,30 @@ from tensorflow.python.platform import test class PruningUtilsTest(test.TestCase): + def _compare_cdf(self, values): + abs_values = math_ops.abs(values) + max_value = math_ops.reduce_max(abs_values) + with self.test_session(): + variables.global_variables_initializer().run() + cdf_from_histogram = pruning_utils.compute_cdf_from_histogram( + abs_values, [0.0, max_value], nbins=pruning_utils._NBINS) + cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value]) + self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval()) + + def _compare_pooling_methods(self, weights, pooling_kwargs): + with self.test_session(): + variables.global_variables_initializer().run() + pooled_weights_tf = array_ops.squeeze( + nn_ops.pool( + array_ops.reshape( + weights, + [1, weights.get_shape()[0], + weights.get_shape()[1], 1]), **pooling_kwargs)) + pooled_weights_factorized_pool = pruning_utils.factorized_pool( + weights, **pooling_kwargs) + self.assertAllClose(pooled_weights_tf.eval(), + pooled_weights_factorized_pool.eval()) + def testHistogram(self): width = 10 height = 10 @@ 
-59,27 +85,35 @@ class PruningUtilsTest(test.TestCase): self.assertAllEqual(len(norm_cdf_val), nbins) self.assertAllEqual(expected_cdf, norm_cdf_val) - def _compare_cdf(self, values): - abs_values = math_ops.abs(values) - max_value = math_ops.reduce_max(abs_values) - with self.test_session(): - variables.global_variables_initializer().run() - cdf_from_histogram = pruning_utils.compute_cdf_from_histogram( - abs_values, [0.0, max_value], nbins=pruning_utils._NBINS) - cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value]) - return cdf.eval(), cdf_from_histogram.eval() - def testCDFEquivalence2D(self): width = 100 height = 100 weights = variable_scope.get_variable("weights", shape=[width, height]) - cdf_val, cdf_from_histogram_val = self._compare_cdf(weights) - self.assertAllEqual(cdf_val, cdf_from_histogram_val) + self._compare_cdf(weights) def testCDFEquivalence4D(self): weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128]) - cdf_val, cdf_from_histogram_val = self._compare_cdf(weights) - self.assertAllEqual(cdf_val, cdf_from_histogram_val) + self._compare_cdf(weights) + + def testFactorizedAvgPool(self): + weights = variable_scope.get_variable("weights", shape=[1024, 2048]) + pooling_kwargs = { + "window_shape": [2, 4], + "pooling_type": "AVG", + "strides": [2, 4], + "padding": "SAME" + } + self._compare_pooling_methods(weights, pooling_kwargs) + + def testFactorizedMaxPool(self): + weights = variable_scope.get_variable("weights", shape=[1024, 2048]) + pooling_kwargs = { + "window_shape": [2, 4], + "pooling_type": "MAX", + "strides": [2, 4], + "padding": "SAME" + } + self._compare_pooling_methods(weights, pooling_kwargs) if __name__ == "__main__": diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 0bdf6f64c9..f84ff1bfe9 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -181,6 +181,7 @@ py_library( ":datasets", ":profiler", ":tpu_py", + "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", "//tensorflow/contrib/tpu/proto:topology_proto_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index 3bdf7c2f83..defed00537 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -64,6 +64,10 @@ REGISTER_OP("TPUReplicatedOutput") "Operator that connects the output of an N-way replicated TPU " "computation to N separate outputs."); +REGISTER_OP("TPUCompilationResult") + .Output("output: string") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("TPUReplicate") .Attr("computation: func") .Attr("num_replicas: int >= 1") diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD index fcfbbe1a21..7ecb36852c 100644 --- a/tensorflow/contrib/tpu/proto/BUILD +++ b/tensorflow/contrib/tpu/proto/BUILD @@ -21,3 +21,13 @@ tf_proto_library( cc_api_version = 2, visibility = ["//visibility:public"], ) + +tf_proto_library( + name = "compilation_result_proto", + srcs = [ + "compilation_result.proto", + ], + cc_api_version = 2, + protodeps = ["//tensorflow/core:protos_all"], + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/tpu/proto/compilation_result.proto b/tensorflow/contrib/tpu/proto/compilation_result.proto new file mode 100644 index 0000000000..cf52897de3 --- /dev/null +++ b/tensorflow/contrib/tpu/proto/compilation_result.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +option 
cc_enable_arenas = true; +package tensorflow.tpu; + +import "tensorflow/core/lib/core/error_codes.proto"; + +// Describes the result of a TPU compilation. +message CompilationResultProto { + // The error message, if any, returned during compilation. + error.Code status_code = 1; + string status_error_message = 2; +} diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index 3455e0b4a6..faf677a81d 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -28,6 +28,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.core.util import event_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging @@ -78,6 +79,15 @@ class WorkerHeartbeatManager(object): return WorkerHeartbeatManager(session, devices, heartbeat_ops, request_placeholder) + def heartbeat_supported(self): + """Returns True if heartbeat operations are supported on all workers.""" + try: + # Send ping to verify worker has heartbeat support. + self.ping() + return True + except errors.InvalidArgumentError as _: + return False + def configure(self, message): """Configure heartbeat manager for all devices. @@ -106,7 +116,7 @@ class WorkerHeartbeatManager(object): event_pb2.WorkerHeartbeatResponse.FromString(res_pb) for res_pb in results ] - logging.info('Results: %s', parsed_results) + logging.debug('Ping results: %s', parsed_results) return parsed_results def lame_workers(self): @@ -189,7 +199,9 @@ class WatchdogManager(threading.Thread): self._running = False self._graph = ops.Graph() self._session = session_lib.Session( - target=session.sess_str, graph=self._graph) + target=session.sess_str, + graph=self._graph, + ) with self._graph.as_default(): if devices is None: @@ -249,6 +261,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): self._graph = ops.Graph() self._workers = None self._session = None + self._heartbeat_supported = False def after_create_session(self, training_session, coord): # pylint: disable=unused-argument # N.B. We have to pull the global step here to avoid it being unavailable @@ -264,10 +277,16 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): target=training_session.sess_str, graph=self._graph) self._workers = WorkerHeartbeatManager.from_devices( self._session, all_worker_devices(self._session)) - - self._workers.configure( - event_pb2.WorkerHeartbeatRequest( - shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) + self._heartbeat_supported = self._workers.heartbeat_supported() + if self._heartbeat_supported: + self._workers.configure( + event_pb2.WorkerHeartbeatRequest( + shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) + else: + logging.warn( + 'Worker heartbeats not supported by all workers. No failure ' + 'handling will be enabled.' 
+ ) def saver(self): if self._saver: @@ -286,6 +305,9 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): def after_run(self, run_context, run_values): del run_values + if not self._heartbeat_supported: + return + lame_workers = self._workers.lame_workers() if lame_workers: logging.info('ShutdownHook: lame workers found: %s', lame_workers) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 7b8786304c..c8f24ed01d 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -58,6 +58,7 @@ _NOT_IMPLEMENTED_OPS = set([ _MAX_WARNING_LINES = 5 _TPU_REPLICATE_ATTR = "_tpu_replicate" +_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status" _OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation" @@ -385,6 +386,45 @@ def replicate(computation, ValueError: If the number of inputs per replica does not match the number of formal parameters to `computation`. """ + return split_compile_and_replicate(computation, inputs, infeed_queue, + device_assignment, name)[1] + + +def split_compile_and_replicate(computation, + inputs=None, + infeed_queue=None, + device_assignment=None, + name=None): + """Builds graph operators that runs compilation and replicated computation. + + This is a lower level interface than replicate that returns a separate compile + and execute output tensor. In the generated graph the compile op feeds into + the execute op and no additional compilation is incurred when running the + compile op before the execute op. The compile op returns additional + information about the compilation but does not return the compiled program. + + Args: + computation: A Python function that builds the computation to replicate. + inputs: A list of lists of input tensors or `None` (equivalent to + `[[]]`), indexed by `[replica_num][input_num]`. All replicas must + have the same number of inputs. + infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple + of arguments as inputs to computation. + device_assignment: If not `None`, a `DeviceAssignment` describing the + mapping between logical cores in the computation with physical cores in + the TPU topology. Uses a default device assignment if `None`. The + `DeviceAssignment` may be omitted if each replica of the computation uses + only one core, and there is either only one replica, or the number of + replicas is equal to the number of cores in the TPU system. + name: (Deprecated) Does nothing. + Returns: + A list of lists with the first list corresponding to the compile op and the + second a list of output tensors, indexed by `[replica_num][output_num]`. + Raises: + ValueError: If all replicas do not have equal numbers of input tensors. + ValueError: If the number of inputs per replica does not match + the number of formal parameters to `computation`. + """ del name inputs = [[]] if inputs is None else inputs @@ -456,8 +496,8 @@ def replicate(computation, computation_inputs.append( tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) - context = TPUReplicateContext( - name=graph.unique_name("cluster"), num_replicas=num_replicas) + cluster_name = graph.unique_name("cluster") + context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas) try: context.Enter() @@ -516,8 +556,7 @@ def replicate(computation, # Separates the returned Operations and Tensors. 
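     # (All Tensors must come before any Operations in the returned list; the
     # ValueError below enforces that ordering.)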
output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs - if not isinstance(o, ops.Operation)] + output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] if outputs != output_tensors + output_operations: raise ValueError( @@ -550,22 +589,33 @@ def replicate(computation, name="output{}".format(i)) for i in xrange(output_arity)] + with ops.control_dependencies([metadata]): + compile_status = tpu_ops.tpu_compilation_result() + op = compile_status.op + attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) + op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access + with ops.control_dependencies(output_operations): if output_arity == 0: # Returns a list of NoOps dependent on the replication Op, indexed by # [replica_num]. return [ - control_flow_ops.no_op(name="shard_%d" % i) - for i in range(num_replicas) + compile_status, [ + control_flow_ops.no_op(name="shard_%d" % i) + for i in range(num_replicas) + ] ] else: # Wraps the outputs in identity operators so the names of any possible # `fetch` nodes are preserved by the replication rewrite. return [ - [array_ops.identity(outputs[out][replica], - name="output_%d_shard_%d" % (out, replica)) - for out in xrange(output_arity)] - for replica in xrange(num_replicas) + compile_status, [[ + array_ops.identity( + outputs[out][replica], + name="output_%d_shard_%d" % (out, replica)) + for out in xrange(output_arity) + ] + for replica in xrange(num_replicas)] ] diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index a69bfa9a20..a624eceed9 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -175,17 +175,7 @@ class _SIGNAL(object): STOP = -2 -class TPUEstimatorSpec( - collections.namedtuple('TPUEstimatorSpec', [ - 'mode', - 'predictions', - 'loss', - 'train_op', - 'eval_metrics', - 'export_outputs', - 'scaffold_fn', - 'host_call' - ])): +class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and @@ -1156,7 +1146,7 @@ class _ModelFnWrapper(object): self._call_model_fn(features, labels)) loss, train_op = estimator_spec.loss, estimator_spec.train_op - if isinstance(estimator_spec, TPUEstimatorSpec): + if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access captured_scaffold_fn.capture(estimator_spec.scaffold_fn) else: captured_scaffold_fn.capture(None) @@ -1165,8 +1155,8 @@ class _ModelFnWrapper(object): # outfeed. 
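     # (The control_dependencies block below is what guarantees the outfeed
     # runs only after train_op.)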
with ops.control_dependencies([train_op]): host_call_outfeed_ops = [] - if (isinstance(estimator_spec, TPUEstimatorSpec) and - estimator_spec.host_call is not None): + if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access + and estimator_spec.host_call is not None): host_call.record({'host_call': estimator_spec.host_call}) host_call_outfeed_ops = host_call.create_enqueue_op() with ops.control_dependencies(host_call_outfeed_ops): @@ -1209,7 +1199,7 @@ class _ModelFnWrapper(object): features, labels = inputs.features_and_labels() tpu_estimator_spec = self._call_model_fn(features, labels) - if not isinstance(tpu_estimator_spec, TPUEstimatorSpec): + if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access raise RuntimeError( 'estimator_spec used by TPU evaluation must have type' '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) @@ -1254,7 +1244,7 @@ class _ModelFnWrapper(object): tpu_estimator_spec = self._call_model_fn( features, labels, is_export_mode=False) - if not isinstance(tpu_estimator_spec, TPUEstimatorSpec): + if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access raise RuntimeError( 'estimator_spec used by TPU prediction must have type' '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) @@ -1316,7 +1306,7 @@ class _ModelFnWrapper(object): estimator_spec = self._model_fn(features=features, **kwargs) if (self._ctx.is_running_on_cpu(is_export_mode) and - isinstance(estimator_spec, TPUEstimatorSpec)): + isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access # The estimator_spec will be passed to `Estimator` directly, which expects # type `EstimatorSpec`. return estimator_spec.as_estimator_spec() @@ -1325,7 +1315,7 @@ class _ModelFnWrapper(object): def _verify_estimator_spec(self, estimator_spec): """Validates the estimator_spec.""" - if isinstance(estimator_spec, TPUEstimatorSpec): + if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access return estimator_spec err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.' diff --git a/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt index bf544703de..e230c51edf 100644 --- a/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt @@ -1,5 +1,19 @@ op { graph_op_name: "MapAndBatchDataset" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <<END +A variant tensor representing the input dataset. +END + } + in_arg { + name: "other_arguments" + description: <<END +A list of tensors, typically values that were captured when building a closure +for `f`. +END + } in_arg { name: "batch_size" description: <<END @@ -11,13 +25,26 @@ END in_arg { name: "num_parallel_batches" description: <<END -A scalar representing the number of batches to create in -parallel. Processing multiple batches in parallel benefits workloads prone to -stragglers. +A scalar representing the number of batches to create in parallel. Processing +multiple batches in parallel benefits workloads prone to stragglers. +END + } + in_arg { + name: "drop_remainder" + description: <<END +A scalar representing whether the last batch should be dropped in case its size +is smaller than desired. 
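+For example, with `batch_size` 4 and 10 input elements, the final batch of 2
+elements is dropped when this is set.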
+END + } + attr { + name: "f" + description: <<END +A function to apply to the outputs of `input_dataset`. END } - summary: "Creates a dataset that applies `f` to the outputs of `input_dataset` and then" + summary: "Creates a dataset that fuses mapping with batching." description: <<END +Creates a dataset that applies `f` to the outputs of `input_dataset` and then batches `batch_size` of them. Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up diff --git a/tensorflow/core/api_def/base_api/api_def_MapAndBatchDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDatasetV2.pbtxt new file mode 100644 index 0000000000..81ef92cae0 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDatasetV2.pbtxt @@ -0,0 +1,54 @@ +op { + graph_op_name: "MapAndBatchDatasetV2" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <<END +A variant tensor representing the input dataset. +END + } + in_arg { + name: "other_arguments" + description: <<END +A list of tensors, typically values that were captured when building a closure +for `f`. +END + } + in_arg { + name: "batch_size" + description: <<END +A scalar representing the number of elements to accumulate in a +batch. It determines the number of concurrent invocations of `f` that process +elements from `input_dataset` in parallel. +END + } + in_arg { + name: "num_parallel_calls" + description: <<END +A scalar representing the maximum number of parallel invocations of the `map_fn` +function. Applying the `map_fn` on consecutive input elements in parallel has +the potential to improve input pipeline throughput. +END + } + in_arg { + name: "drop_remainder" + description: <<END +A scalar representing whether the last batch should be dropped in case its size +is smaller than desired. +END + } + attr { + name: "f" + description: <<END +A function to apply to the outputs of `input_dataset`. +END + } + summary: "Creates a dataset that fuses mapping with batching." + description: <<END +Creates a dataset that applies `f` to the outputs of `input_dataset` and then +batches `batch_size` of them. + +Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up +to `batch_size * num_parallel_batches` copies of `f` in parallel. +END +} diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index e389eb9b2a..7d63626b95 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -272,9 +272,9 @@ struct NodeItem { // (uint8 is enough for DataType). // EdgeInfo out_edges[num_out_edges]; // AllocatorAttributes output_attr[num_outputs]; + // int forward_from[num_outputs]; // uint8 input_type[num_inputs]; // uint8 output_type[num_outputs]; - // int forward_from[num_outputs]; // Return pointer to variable length section. 
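  // (Sections are laid out in the order listed above; forward_from, an int
  // array, now precedes the uint8 arrays so it stays naturally aligned.)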
char* var() const { @@ -289,22 +289,20 @@ struct NodeItem { return reinterpret_cast<AllocatorAttributes*>(var() + sizeof(EdgeInfo) * num_output_edges); } + int* forward_from_base() const { + return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges + + sizeof(AllocatorAttributes) * num_outputs); + } uint8* input_type_base() const { - return reinterpret_cast<uint8*>(var() + - sizeof(EdgeInfo) * num_output_edges + - sizeof(AllocatorAttributes) * num_outputs); + return reinterpret_cast<uint8*>( + var() + sizeof(EdgeInfo) * num_output_edges + + sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs); } uint8* output_type_base() const { return reinterpret_cast<uint8*>( var() + sizeof(EdgeInfo) * num_output_edges + - sizeof(AllocatorAttributes) * num_outputs + sizeof(uint8) * num_inputs); - } - - int* forward_from_base() const { - return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges + - sizeof(AllocatorAttributes) * num_outputs + - sizeof(uint8) * num_inputs + - sizeof(uint8) * num_outputs); + sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs + + sizeof(uint8) * num_inputs); } TF_DISALLOW_COPY_AND_ASSIGN(NodeItem); @@ -481,9 +479,9 @@ size_t GraphView::NodeItemBytes(const Node* n) { sizeof(NodeItem) // Fixed + num_output_edges * sizeof(EdgeInfo) // output_edges[...] + num_outputs * sizeof(AllocatorAttributes) // output_attr[...] + + num_outputs * sizeof(int) // forward_from[num_outputs] + num_inputs * sizeof(uint8) // input_type[num_inputs] - + num_outputs * sizeof(uint8) // output_type[num_outputs] - + num_outputs * sizeof(int); // forward_from[num_outputs] + + num_outputs * sizeof(uint8); // output_type[num_outputs] static constexpr size_t kItemAlignment = sizeof(NodeItem*); static_assert(kItemAlignment % alignof(NodeItem) == 0, "NodeItem must be aligned with kItemAlignment"); diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index a6f637b488..bf05f6f1d9 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -795,16 +795,16 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, }; } - if (run_opts.runner == nullptr) { - run_opts.runner = &default_runner_; - } - DCHECK(run_opts.runner != nullptr); - if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { parent_->Run(run_opts, handle, args, rets, done); return; } + if (run_opts.runner == nullptr) { + run_opts.runner = &default_runner_; + } + DCHECK(run_opts.runner != nullptr); + Executor::Args* exec_args = new Executor::Args; // Inherit the step_id from the caller. exec_args->step_id = run_opts.step_id; diff --git a/tensorflow/core/common_runtime/profile_handler.h b/tensorflow/core/common_runtime/profile_handler.h index 9d31b1aecb..391dc8c198 100644 --- a/tensorflow/core/common_runtime/profile_handler.h +++ b/tensorflow/core/common_runtime/profile_handler.h @@ -29,22 +29,6 @@ class ProfileHandler { ProfileHandler() {} virtual ~ProfileHandler() {} - // Records that a miscellaneous activity occurred in the current step. - // - // Implementations of this method must be thread-safe. - // - // Args: - // - device: The device on which the activity occurred. - // - start: The time at which the activity started. - // - limit: The time at which the activity finished. - // - label: A label for the op, which may be used in visualization. - // - op_type: A type string for the op, which may be used in visualization. 
- // - details: A details string, which may be used in visualization. - // from time "start" to "limit" with "op_type" and "details". - virtual void RecordActivity(const string& device, Microseconds start, - Microseconds limit, StringPiece label, - StringPiece op_type, StringPiece details) = 0; - // Records that a single Op was executed in the current step. // // Implementations of this method must be thread-safe. diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 06dbe04986..fa4d1eda62 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -232,13 +232,12 @@ Status ShapeRefiner::AddNode(const Node* node) { input_nodes[e->dst_input()] = input; input_shapes[e->dst_input()] = c->output(e->src_output()); - // Only propagate handle data of edges which are carrying resource handles. - if (e->src()->output_type(e->src_output()) == DT_RESOURCE) { - const auto* in_v = c->output_handle_shapes_and_types(e->src_output()); - if (in_v != nullptr) { - input_handle_shapes_and_types[e->dst_input()].reset( - new std::vector<ShapeAndType>(*in_v)); - } + const auto* in_v = c->output_handle_shapes_and_types(e->src_output()); + if (in_v != nullptr) { + DataType input_type = e->src()->output_type(e->src_output()); + DCHECK(input_type == DT_RESOURCE || input_type == DT_VARIANT); + input_handle_shapes_and_types[e->dst_input()].reset( + new std::vector<ShapeAndType>(*in_v)); } } @@ -422,6 +421,28 @@ Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node, kMaxTensorSize, disable_constant_propagation_); } +Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node, + int dst_idx, bool* evaluated, + int64* result) { + Tensor scalar; + TF_RETURN_IF_ERROR( + EvaluateConstantTensorForEdge(node, dst_idx, evaluated, &scalar)); + if (*evaluated) { + DCHECK_EQ(scalar.NumElements(), 1) + << "EvaluateConstantIntScalarEdge called on non-scalar edge: " + << scalar.NumElements(); + if (scalar.dtype() == DT_INT32) { + *result = scalar.scalar<int32>()(); + } else { + DCHECK_EQ(scalar.dtype(), DT_INT64) + << "EvaluateConstantIntScalarEdge called on non-integer edge: " + << scalar.dtype(); + *result = scalar.scalar<int64>()(); + } + } + return Status::OK(); +} + Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, const Node* node, int dst_idx, ShapeHandle* result) { @@ -472,19 +493,11 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, std::vector<DimensionHandle> dims; // Pack is concatenating its input scalars to form the shape tensor vector. for (int i = 0; i < src_context->num_inputs(); ++i) { - Tensor scalar; - bool evaluated = false; - TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(input_edge->src(), i, - &evaluated, &scalar)); + int64 size; + bool evaluated; + TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(input_edge->src(), i, + &evaluated, &size)); if (evaluated) { - int64 size; - if (scalar.dtype() == DT_INT32) { - size = scalar.scalar<int32>()(); - } else if (scalar.dtype() == DT_INT64) { - size = scalar.scalar<int64>()(); - } else { - return errors::InvalidArgument("Pack input must be int32 or int64"); - } dims.push_back(size < 0 ? 
target_context->UnknownDim() : target_context->MakeDim(size)); } else { @@ -514,6 +527,9 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, TF_RETURN_IF_ERROR( target_context->Concatenate(*result, sub_result, result)); } + } else if (src_op == "StridedSlice") { + TF_RETURN_IF_ERROR( + PartialStridedSliceShape(input_edge->src(), src_context, result)); } else { Tensor t; bool evaluated = false; @@ -525,6 +541,78 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, return Status::OK(); } +Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node, + InferenceContext* ctx, + ShapeHandle* result) { + // Only attempt to evaluate if begin/end/strides all are scalars. + for (int i = 1; i <= 3; ++i) { + ShapeHandle input_shape = ctx->input(i); + if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) { + *result = ctx->UnknownShape(); + return Status::OK(); + } + } + + int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask; + TF_RETURN_IF_ERROR( + GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask)); + TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask)); + TF_RETURN_IF_ERROR( + GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask)); + TF_RETURN_IF_ERROR( + GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask)); + TF_RETURN_IF_ERROR( + GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask)); + + // Only attempt to evaluate if there are no special masks set (note that we + // can handle begin/end_mask == 1). + if (!(begin_mask == 0 || begin_mask == 1) || + !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 || + new_axis_mask != 0 || shrink_axis_mask != 0) { + *result = ctx->UnknownShape(); + return Status::OK(); + } + + bool evaluated; + int64 begin; + if (begin_mask == 1) { + begin = 0; + } else { + TF_RETURN_IF_ERROR( + EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, &begin)); + if (!evaluated) { + *result = ctx->UnknownShape(); + return Status::OK(); + } + } + + int64 end; + if (end_mask == 1) { + end = std::numeric_limits<int64>::max(); + } else { + TF_RETURN_IF_ERROR( + EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, &end)); + if (!evaluated) { + *result = ctx->UnknownShape(); + return Status::OK(); + } + } + + int64 stride; + TF_RETURN_IF_ERROR( + EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, &stride)); + if (!evaluated) { + *result = ctx->UnknownShape(); + return Status::OK(); + } + + // Apply stride to input interpreted as a partial shape. + ShapeHandle input; + TF_RETURN_IF_ERROR(ConstantPartialShape(ctx, slice_node, 0, &input)); + TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result)); + return Status::OK(); +} + Status ShapeRefiner::RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, ExtendedInferenceContext* ec) { diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index d49c4373f0..9c96dcbc20 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -215,9 +215,18 @@ class ShapeRefiner { bool keep_nested_shapes, ExtendedInferenceContext* outer_context); + // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge + // value can be evaluated, 'evaluated' is set to true and the value returned + // in 'result'. Otherwise 'evaluated' is set to false. 
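+  // Evaluation may require extracting and running a small subgraph of the
+  // edge's producers, so it is best-effort and bounded by kMaxTensorSize.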
Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx, bool* evaluated, Tensor* result); + // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input + // tensors. The caller is responsible for checking that the specified edge is + // scalar and int32 or int64. + Status EvaluateConstantIntScalarEdge(const Node* node, int dst_idx, + bool* evaluated, int64* result); + // This function tries to materialize as much information about the 'node''s // dst_idx input as a statically computable shape, and the result may be // partially known, depending on what is statically inferable. @@ -243,6 +252,11 @@ class ShapeRefiner { const Node* node, int dst_idx, shape_inference::ShapeHandle* result); + // Implementation of ConstantPartialShape for StridedSlice nodes. + Status PartialStridedSliceShape(Node* slice_node, + shape_inference::InferenceContext* ctx, + shape_inference::ShapeHandle* result); + Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, ExtendedInferenceContext* ec); diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index f48638afc0..8b9657eec8 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -60,6 +60,39 @@ class ShapeRefinerTest : public ::testing::Test { } static constexpr int64 kMaxTensorSize = ShapeRefiner::kMaxTensorSize; + + void TestStridedSlice(const PartialTensorShape& input_shape, int begin, + int end, int stride, const char* expected, + int begin_mask = 0, int end_mask = 0, + int ellipsis_mask = 0) { + Scope root = Scope::DisabledShapeInferenceScope(); + auto placeholder = + ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape)); + auto input = ops::Shape(root, placeholder); + auto begin_op = ops::Const(root, {begin}); + auto end_op = ops::Const(root, {end}); + auto stride_op = ops::Const(root, {stride}); + auto slice = ops::StridedSlice(root, input, begin_op, end_op, stride_op, + ops::StridedSlice::BeginMask(begin_mask) + .EndMask(end_mask) + .EllipsisMask(ellipsis_mask)); + Node* result; + TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32") + .Input(slice.node()) + .Finalize(root.graph(), &result)); + + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); + TF_ASSERT_OK(m.AddNode(placeholder.node())); + TF_ASSERT_OK(m.AddNode(input.node())); + TF_ASSERT_OK(m.AddNode(begin_op.node())); + TF_ASSERT_OK(m.AddNode(end_op.node())); + TF_ASSERT_OK(m.AddNode(stride_op.node())); + TF_ASSERT_OK(m.AddNode(slice.node())); + TF_ASSERT_OK(m.AddNode(result)); + + shape_inference::InferenceContext* ctx = m.GetContext(result); + EXPECT_EQ(ctx->DebugString(ctx->output(0)), expected); + } }; namespace { @@ -1156,6 +1189,73 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) { m.AddNode(result).error_message()); } +TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSlice) { + TestStridedSlice( + /*input_shape=*/{1, -1, 3, -1, 5}, + /*begin=*/2, + /*end=*/5, + /*stride=*/1, + /*expected=*/"[3,?,5]"); +} + +TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceNegativeStride) { + // clang-format off + TestStridedSlice( + /*input_shape=*/{1, -1, 3, -1, 5}, + /*begin=*/10, + /*end=*/0, + /*stride=*/-1, + /*expected=*/"[5,?,3,?]"); + // clang-format on +} + +TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMasks) { + TestStridedSlice( + /*input_shape=*/{1, -1, 3, -1, 5}, + /*begin=*/3, + /*end=*/4, + /*stride=*/1, + 
/*expected=*/"[1,?,3,?,5]", + /*begin_mask=*/1, + /*end_mask=*/1); +} + +TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceInvalidMask) { + TestStridedSlice( + /*input_shape=*/{1, -1, 3}, + /*begin=*/2, + /*end=*/3, + /*stride=*/1, + /*expected=*/"[?,?,?]", + /*begin_mask=*/0, + /*end_mask=*/0, + /*ellipsis_mask=*/1); +} + +TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMulti) { + Scope root = Scope::DisabledShapeInferenceScope(); + auto input = ops::Placeholder(root, DT_INT32); + auto begin = ops::Const(root, {0, 0}); + auto end = ops::Const(root, {2, 2}); + auto stride = ops::Const(root, {1, 1}); + auto slice = ops::StridedSlice(root, input, begin, end, stride); + Node* result; + TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32") + .Input(slice.node()) + .Finalize(root.graph(), &result)); + + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); + TF_ASSERT_OK(m.AddNode(input.node())); + TF_ASSERT_OK(m.AddNode(begin.node())); + TF_ASSERT_OK(m.AddNode(end.node())); + TF_ASSERT_OK(m.AddNode(stride.node())); + TF_ASSERT_OK(m.AddNode(slice.node())); + TF_ASSERT_OK(m.AddNode(result)); + + shape_inference::InferenceContext* ctx = m.GetContext(result); + EXPECT_EQ(ctx->DebugString(ctx->output(0)), "?"); +} + namespace { // Dummy op to test ShapeRefiner util functions diff --git a/tensorflow/core/framework/api_def.proto b/tensorflow/core/framework/api_def.proto index cce02d84b2..3f8dd272e7 100644 --- a/tensorflow/core/framework/api_def.proto +++ b/tensorflow/core/framework/api_def.proto @@ -56,8 +56,10 @@ message ApiDef { // use a snake_case convention instead of CamelCase. string name = 1; - // First GraphDef version at which the op is disallowed. - int32 deprecation_version = 2; + // If this endpoint is deprecated, set deprecation_message to a + // message that should be logged when the endpoint is used. + // The message should indicate alternative endpoint to use, if any. + string deprecation_message = 2; } repeated Endpoint endpoint = 3; diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 4145ef7bc9..62a9d5751d 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/node_builder.h" @@ -269,4 +270,22 @@ const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH"; const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] = "_DATASET_GRAPH_OUTPUT_NODE"; +namespace dataset { + +IteratorContext MakeIteratorContext(OpKernelContext* ctx) { + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = *(ctx->runner()); + params.lib = ctx->function_library(); + // Note: must use reinterpret_cast because function.h forward-declares Device. 
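+  // (Device subclasses DeviceBase, but is an incomplete type at this point,
+  // so a static_cast would not compile.)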
+ DeviceBase* device = + reinterpret_cast<DeviceBase*>(ctx->function_library()->device()); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; + return IteratorContext(params); +} + +} // namespace dataset + } // namespace tensorflow diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 775d9f6eb6..8624af9bf5 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -619,6 +619,12 @@ Status GetDatasetFromVariantTensor(const Tensor& tensor, // The ownership of `dataset` is transferred to `tensor`. Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor); +namespace dataset { + +IteratorContext MakeIteratorContext(OpKernelContext* ctx); + +} // namespace dataset + } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index bdc1af9fda..647c66099c 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -504,7 +504,7 @@ string Print(const NodeDef& n) { std::vector<string> dep; for (StringPiece s : n.input()) { if (str_util::ConsumePrefix(&s, "^")) { - dep.push_back(s.ToString()); + dep.push_back(std::string(s)); } else { dat.push_back(s); } diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc index f9cf6ce873..8e00bfe4f8 100644 --- a/tensorflow/core/framework/node_def_builder.cc +++ b/tensorflow/core/framework/node_def_builder.cc @@ -24,22 +24,23 @@ limitations under the License. namespace tensorflow { NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt) - : node(n.ToString()), index(i), data_type(dt) {} + : node(std::string(n)), index(i), data_type(dt) {} NodeDefBuilder::NodeOut::NodeOut() { // uninitialized, call Reset() before use. 
} void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) { - node = n.ToString(); + node = std::string(n); index = i; data_type = dt; } NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, const OpRegistryInterface* op_registry) { - node_def_.set_name(name.ToString()); - const Status status = op_registry->LookUpOpDef(op_name.ToString(), &op_def_); + node_def_.set_name(std::string(name)); + const Status status = + op_registry->LookUpOpDef(std::string(op_name), &op_def_); if (status.ok()) { Initialize(); } else { @@ -50,7 +51,7 @@ NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def) : op_def_(op_def) { - node_def_.set_name(name.ToString()); + node_def_.set_name(std::string(name)); Initialize(); } @@ -170,7 +171,7 @@ void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) { } else if (src_index > 0) { node_def_.add_input(strings::StrCat(src_node, ":", src_index)); } else { - node_def_.add_input(src_node.ToString()); + node_def_.add_input(std::string(src_node)); } } @@ -193,12 +194,12 @@ void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg, } NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) { - control_inputs_.push_back(src_node.ToString()); + control_inputs_.push_back(std::string(src_node)); return *this; } NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) { - node_def_.set_device(device_spec.ToString()); + node_def_.set_device(std::string(device_spec)); return *this; } diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index bad92ca9b3..5798333dfe 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -245,7 +245,7 @@ DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;); #undef DEFINE_GET_ATTR bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) { - return node_def.attr().find(attr_name.ToString()) != node_def.attr().end(); + return node_def.attr().find(std::string(attr_name)) != node_def.attr().end(); } static const string& kEmptyString = *new string(); @@ -639,7 +639,7 @@ Status AttachDef(const Status& status, const Node& node) { void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) { node_def->mutable_attr()->insert( - AttrValueMap::value_type(name.ToString(), value)); + AttrValueMap::value_type(std::string(name), value)); } #define ADD_NODE_ATTR(T) \ @@ -677,7 +677,7 @@ ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>) #undef ADD_NODE_ATTR void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) { - map->insert(AttrValueMap::value_type(name.ToString(), value)); + map->insert(AttrValueMap::value_type(std::string(name), value)); } #define ADD_ATTR(T) \ diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index 403bd0b5e2..91eb6c0672 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -527,7 +527,7 @@ void FinalizeDoc(const string& text, OpDef* op_def, } // namespace OpDefBuilder::OpDefBuilder(StringPiece op_name) { - op_def()->set_name(op_name.ToString()); // NOLINT + op_def()->set_name(std::string(op_name)); // NOLINT } OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) { @@ -584,7 +584,7 @@ OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) { } else { OpDeprecation* deprecation = op_def()->mutable_deprecation(); 
deprecation->set_version(version); - deprecation->set_explanation(explanation.ToString()); + deprecation->set_explanation(std::string(explanation)); } return *this; } diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc index 7f23272871..3d7920a6e2 100644 --- a/tensorflow/core/framework/op_gen_lib.cc +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -185,7 +185,7 @@ static bool FindMultiline(StringPiece line, size_t colon, string* end) { while (str_util::ConsumePrefix(&line, " ")) { } if (str_util::ConsumePrefix(&line, "<<")) { - *end = line.ToString(); + *end = std::string(line); return true; } return false; @@ -306,9 +306,6 @@ void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) { auto* endpoint = api_def->add_endpoint(); endpoint->set_name(op_def.name()); - if (op_def.has_deprecation()) { - endpoint->set_deprecation_version(op_def.deprecation().version()); - } for (const auto& op_in_arg : op_def.input_arg()) { auto* api_in_arg = api_def->add_in_arg(); diff --git a/tensorflow/core/framework/op_gen_lib_test.cc b/tensorflow/core/framework/op_gen_lib_test.cc index 857b1c8dbc..e0e77c7449 100644 --- a/tensorflow/core/framework/op_gen_lib_test.cc +++ b/tensorflow/core/framework/op_gen_lib_test.cc @@ -189,7 +189,6 @@ TEST(OpGenLibTest, ApiDefInitializedFromOpDef) { visibility: VISIBLE endpoint { name: "testop" - deprecation_version: 123 } in_arg { name: "arg_a" diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index ca91d68f79..c71bcb26ab 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -923,7 +923,7 @@ void OpKernelContext::clear_recorded_memory() { struct KernelRegistration { KernelRegistration(const KernelDef& d, StringPiece c, kernel_factory::OpKernelRegistrar::Factory f) - : def(d), kernel_class_name(c.ToString()), factory(f) {} + : def(d), kernel_class_name(std::string(c)), factory(f) {} const KernelDef def; const string kernel_class_name; const kernel_factory::OpKernelRegistrar::Factory factory; diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index c84ea3b034..3cc17e1ca6 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -338,6 +338,9 @@ class ResourceHandleOp : public OpKernel { private: string container_; string name_; + mutex mutex_; + Tensor resource_ GUARDED_BY(mutex_); + std::atomic<bool> initialized_{false}; }; // Registers a kernel for an op which produces a handle to a resource of the @@ -511,10 +514,17 @@ ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context) template <typename T> void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) { - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - output->scalar<ResourceHandle>()() = - MakeResourceHandle<T>(ctx, container_, name_); + if (!initialized_.load()) { + mutex_lock ml(mutex_); + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), + &resource_, attr)); + resource_.scalar<ResourceHandle>()() = + MakeResourceHandle<T>(ctx, container_, name_); + initialized_.store(true); + } + ctx->set_output(0, resource_); } } // end namespace tensorflow diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 2b995e8b5e..3185875e3b 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ 
b/tensorflow/core/framework/shape_inference.cc
@@ -605,10 +605,16 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start,
 return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
 }
-Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
+Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
 ShapeHandle* out) {
- int64 start = start_in;
- int64 end = end_in;
+ return Subshape(s, start, end, 1 /* stride */, out);
+}
+
+Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
+ int64 stride, ShapeHandle* out) {
+ int64 start_in = start;
+ int64 end_in = end;
+
 const int32 rank = Rank(s);
 if (start == 0 && ((RankKnown(s) && end >= rank) ||
 end == std::numeric_limits<int64>::max())) {
@@ -621,6 +627,9 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
 if (start > rank) start = rank;
 if (end > rank) end = rank;
+
+ if (stride < 0 && start == rank) --start;
+
 if (start < 0) {
 start = rank + start;
 if (start < 0) {
@@ -638,16 +647,24 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
 ", for shape with rank ", rank);
 }
 }
- if (start > end) {
+ if (stride > 0 && start > end) {
 *out = nullptr;
 return errors::InvalidArgument(
 "Subshape must have computed start <= end, but is ", start, " and ", end,
 " (computed from start ", start_in, " and end ", end_in,
 " over shape with rank ", rank, ")");
+ } else if (stride < 0 && start < end) {
+ *out = nullptr;
+ return errors::InvalidArgument(
+ "Subshape must have computed start >= end since stride is negative, "
+ "but is ",
+ start, " and ", end, " (computed from start ", start_in, " and end ",
+ end_in, " over shape with rank ", rank, " and stride ", stride, ")");
 }
+
 std::vector<DimensionHandle> dims;
- dims.reserve(end - start);
- for (int i = start; i < end; ++i) {
+ dims.reserve((end - start) / stride);
+ for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
 dims.push_back(Dim(s, i));
 }
 return ReturnCreatedShape(dims, out);
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 9431a62abe..3f3729dcf9 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -434,6 +434,13 @@ class InferenceContext {
 Status Subshape(ShapeHandle s, int64 start, int64 end,
 ShapeHandle* out) TF_MUST_USE_RESULT;
+ // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
+ // <start> and <end> can be negative, to index from the end of the shape.
+ // <start> and <end> are set to the rank of <s> if > rank of <s>.
+ // <stride> can be negative, to traverse <s> in reverse order.
+ Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride,
+ ShapeHandle* out) TF_MUST_USE_RESULT;
+
 // Returns in <*out> the result of appending the dimensions of <s2> to those
 // of <s1>.
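A usage sketch for the strided Subshape overload declared above, with illustrative shape values (stride 1 reproduces the old two-bound behavior; a negative stride walks from start down to, but excluding, end):

  // Inside a shape function, given InferenceContext* c and s = [2, 3, 5]:
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->Subshape(s, 0, 2, 1, &out));   // out = [2, 3]
  TF_RETURN_IF_ERROR(c->Subshape(s, 2, 0, -1, &out));  // out = [5, 3]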
Status Concatenate(ShapeHandle s1, ShapeHandle s2, diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h index 2a99af7659..f6656b3b45 100644 --- a/tensorflow/core/framework/shape_inference_testutil.h +++ b/tensorflow/core/framework/shape_inference_testutil.h @@ -32,7 +32,7 @@ class Tensor; struct ShapeInferenceTestOp { typedef std::pair<string, DataType> ShapeAndType; - explicit ShapeInferenceTestOp(StringPiece name) : name(name.ToString()) {} + explicit ShapeInferenceTestOp(StringPiece name) : name(std::string(name)) {} string name; NodeDef node_def; std::vector<const Tensor*> input_tensors; diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index eeb6c60f71..71d0637dc2 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -695,7 +695,7 @@ Status Graph::AddWhileContext(StringPiece frame_name, std::vector<OutputTensor> body_outputs, WhileContext** result) { auto pair = while_ctxs_.insert(std::pair<string, WhileContext>( - frame_name.ToString(), + std::string(frame_name), WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes), cond_output, std::move(body_inputs), std::move(body_outputs)))); diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index c678283fce..2fd32c0bd4 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -489,7 +489,7 @@ Status GraphConstructor::InitFromEdges() { num_control_edges++; } else { TensorId id(ParseTensorName(input_name)); - if (next_iteration_nodes_.find(id.first.ToString()) != + if (next_iteration_nodes_.find(std::string(id.first)) != next_iteration_nodes_.end()) { has_loop_back_edge = true; } @@ -811,7 +811,7 @@ void GraphConstructor::UniquifyNames( // We require that UniquifyNames() is called on all NodeDefs in topological // order. This guarantees that node_def's inputs will already be uniquified // if necessary. - auto iter = uniquified_names_.find(id.first.ToString()); + auto iter = uniquified_names_.find(std::string(id.first)); if (iter == uniquified_names_.end()) continue; id.first = iter->second; node_def->set_input(i, id.ToString()); @@ -830,7 +830,7 @@ void GraphConstructor::UpdateUniquifiedColocationNames() { for (int i = 0; i < coloc_values.size(); ++i) { StringPiece val(coloc_values[i]); if (str_util::ConsumePrefix(&val, kColocationGroupPrefix)) { - const auto& name_pair = uniquified_names_.find(val.ToString()); + const auto& name_pair = uniquified_names_.find(std::string(val)); if (name_pair == uniquified_names_.end()) continue; updated = true; coloc_values[i] = @@ -856,7 +856,7 @@ bool GraphConstructor::NameExistsInGraphDef(StringPiece name) { } string GraphConstructor::FindUniqueName(StringPiece original_name) { - string name = original_name.ToString(); + string name = std::string(original_name); int count = 0; // Check that any generated names don't collide with imported NodeDefs (as // well as nodes in g_). 
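The mechanical change running through the files above replaces StringPiece::ToString() with an explicit std::string construction. The explicit conversion is the one spelling that works both for TensorFlow's legacy StringPiece class and for an absl::string_view / std::string_view alias, neither of which has a ToString() member. A minimal standalone sketch of the idiom (the function name is illustrative, not part of the patch):

  #include <string>
  #include "tensorflow/core/lib/core/stringpiece.h"

  // Copies a borrowed StringPiece into an owned std::string. Unlike
  // s.ToString(), std::string(s) keeps compiling after StringPiece becomes
  // an alias for a string_view type.
  std::string ToOwned(tensorflow::StringPiece s) { return std::string(s); }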
@@ -989,7 +989,7 @@ Status GraphConstructor::Convert() { src_node->num_outputs(), " outputs"); } - inputs.emplace_back(id.first.ToString(), src_node, src_index); + inputs.emplace_back(std::string(id.first), src_node, src_index); } if (has_data_back_edge && !IsMerge(*node_def)) { diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index b513778de9..c54b4fa269 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -157,7 +157,7 @@ class GraphConstructorTest : public ::testing::Test { } StringPiece loc(value[0]); return str_util::ConsumePrefix(&loc, kColocationGroupPrefix) - ? loc.ToString() + ? std::string(loc) : ""; } diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc index 7a58347bd1..dd84c4f7c7 100644 --- a/tensorflow/core/graph/graph_def_builder.cc +++ b/tensorflow/core/graph/graph_def_builder.cc @@ -44,12 +44,12 @@ GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs( } GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl( StringPiece name) { - name_ = name.ToString(); + name_ = std::string(name); return *this; } GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl( StringPiece device) { - device_ = device.ToString(); + device_ = std::string(device); return *this; } GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl( diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h index 776a74c6d8..0d6aae4355 100644 --- a/tensorflow/core/graph/graph_def_builder.h +++ b/tensorflow/core/graph/graph_def_builder.h @@ -128,7 +128,7 @@ class GraphDefBuilder { Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs); template <class T> Options WithAttrImpl(StringPiece name, T&& value) { - attrs_.emplace_back(name.ToString(), AttrValue()); + attrs_.emplace_back(std::string(name), AttrValue()); SetAttrValue(std::forward<T>(value), &attrs_.back().second); return *this; } diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 877e4f1b44..1b1941f9c1 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -785,7 +785,7 @@ Status TopologicalSortNodesWithTimePriority( for (int n = 0; n < gdef->node_size(); ++n) { const NodeDef* ndef = &gdef->node(n); for (int i = 0; i < ndef->input_size(); ++i) { - node_to_output_nodes[ParseTensorName(ndef->input(i)).first.ToString()] + node_to_output_nodes[std::string(ParseTensorName(ndef->input(i)).first)] .push_back(ndef); } int64 start_time; diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index 114962c0e4..03f3bbd663 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -30,7 +30,7 @@ NodeBuilder::NodeOut::NodeOut(Node* n, int32 i) // NOLINT(runtime/explicit) dt(SafeGetOutput(node, i, &error)) {} NodeBuilder::NodeOut::NodeOut(StringPiece n, int32 i, DataType t) - : node(nullptr), error(false), name(n.ToString()), index(i), dt(t) {} + : node(nullptr), error(false), name(std::string(n)), index(i), dt(t) {} NodeBuilder::NodeOut::NodeOut() : node(nullptr), error(true), index(0), dt(DT_FLOAT) {} diff --git a/tensorflow/core/graph/while_context.cc b/tensorflow/core/graph/while_context.cc index 10a2b67f37..1b38aac35d 100644 --- a/tensorflow/core/graph/while_context.cc +++ b/tensorflow/core/graph/while_context.cc @@ 
-23,7 +23,7 @@ WhileContext::WhileContext(StringPiece frame_name,
 OutputTensor cond_output,
 std::vector<OutputTensor> body_inputs,
 std::vector<OutputTensor> body_outputs)
- : frame_name_(frame_name.ToString()),
+ : frame_name_(std::string(frame_name)),
 enter_nodes_(std::move(enter_nodes)),
 exit_nodes_(std::move(exit_nodes)),
 cond_output_(cond_output),
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index b35873ce38..2542fa2d67 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -27,10 +27,16 @@ namespace grappler {
 constexpr int kOpsPerMac = 2;
 constexpr char kConst[] = "Const";
+constexpr char kGuaranteeConst[] = "GuaranteeConst";
 constexpr char kConv2d[] = "Conv2D";
 constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
 constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
 constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
+constexpr char kDepthwiseConv2dNative[] = "DepthwiseConv2dNative";
+constexpr char kDepthwiseConv2dNativeBackpropFilter[] =
+ "DepthwiseConv2dNativeBackpropFilter";
+constexpr char kDepthwiseConv2dNativeBackpropInput[] =
+ "DepthwiseConv2dNativeBackpropInput";
 constexpr char kMatMul[] = "MatMul";
 constexpr char kSparseMatMul[] = "SparseMatMul";
 constexpr char kPlaceholder[] = "Placeholder";
@@ -200,11 +206,20 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
 wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
 {kFusedConv2dBiasActivation,
 wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation)},
+ // Reuse Conv2D for DepthwiseConv2dNative because the calculation is the
+ // same, although the actual meanings of the parameters are different. See
+ // the comments in PredictConv2D and related functions.
+ {kDepthwiseConv2dNative, wrap(&OpLevelCostEstimator::PredictConv2D)},
+ {kDepthwiseConv2dNativeBackpropFilter,
+ wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)},
+ {kDepthwiseConv2dNativeBackpropInput,
+ wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
 {kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
 {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
 {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
 {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kGuaranteeConst, wrap(&OpLevelCostEstimator::PredictNoOp)},
 {kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
 {kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
@@ -537,18 +552,30 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
 int64 OpLevelCostEstimator::CountConv2DOperations(
 const OpInfo& op_features, ConvolutionDimensions* conv_info,
 bool* found_unknown_shapes) const {
- if (op_features.op() != kConv2d) {
- LOG(ERROR) << "Invalid Operation";
- return 0;
- }
+ DCHECK(op_features.op() == kConv2d ||
+ op_features.op() == kDepthwiseConv2dNative)
+ << "Invalid Operation: not Conv2D nor DepthwiseConv2dNative";
+
 ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
 op_features.inputs(0).shape(), op_features.inputs(1).shape(),
 op_features, found_unknown_shapes);
+ // In DepthwiseConv2dNative, conv_dims.oz is actually the channel depth
+ // multiplier; the effective output channel depth oz_effective is
+ // conv_dims.iz * conv_dims.oz, thus # ops = N x H x W x oz_effective x 2RS.
+ // Compare to Conv2D, where # ops = N x H x W x iz x oz x 2RS; with
+ // oz = oz_effective, Conv2D_ops / Depthwise_conv2d_native_ops = iz.
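As a quick sanity check of the formula above, the shapes used by the DepthwiseConv2dNativeExecutionTime test added further down (16x19x19x48 input, 5x5 kernel, 48 input channels, channel multiplier 3; output spatial dims assumed equal to the input's, i.e. stride 1) give:

  // Illustrative arithmetic only; mirrors # ops = N x H x W x oz_effective x 2RS.
  int64 ops = 16;   // batch N
  ops *= 19 * 19;   // output H x W
  ops *= 5 * 5;     // kernel R x S
  ops *= 48 * 3;    // oz_effective = iz * channel multiplier
  ops *= 2;         // kOpsPerMac: one multiply plus one add per MAC
  // ops == 41587200; a Conv2D with iz = 48 and oz = 144 would cost 48x more.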
int64 ops = conv_dims.batch;
 ops *= conv_dims.ox * conv_dims.oy;
 ops *= conv_dims.kx * conv_dims.ky;
- ops *= conv_dims.iz * conv_dims.oz;
+ if (op_features.op() == kConv2d) {
+ ops *= conv_dims.iz * conv_dims.oz;
+ } else {
+ // Ensure the output tensor dims are correct for DepthwiseConv2dNative,
+ // even though the op count is the same as for Conv2D.
+ conv_dims.oz *= conv_dims.iz;
+ ops *= conv_dims.oz;
+ }
 ops *= kOpsPerMac;
 if (conv_info != nullptr) {
@@ -795,7 +822,10 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
 bool* found_unknown_shapes) const {
 int64 ops = 0;
- DCHECK_EQ(kConv2dBackpropInput, op_features.op());
+ DCHECK(op_features.op() == kConv2dBackpropInput ||
+ op_features.op() == kDepthwiseConv2dNativeBackpropInput)
+ << "Invalid Operation: not kConv2dBackpropInput nor "
+ "kDepthwiseConv2dNativeBackpropInput";
 if (op_features.inputs_size() < 2) {
 *found_unknown_shapes = true;
@@ -828,10 +858,15 @@ int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations(
 ops = conv_dims.batch;
 ops *= conv_dims.ox * conv_dims.oy;
 ops *= conv_dims.kx * conv_dims.ky;
- ops *= conv_dims.iz * conv_dims.oz;
- ops *= kOpsPerMac;
+ if (op_features.op() == kConv2dBackpropInput) {
+ ops *= conv_dims.iz * conv_dims.oz;
+ } else {
+ // conv_dims always uses the forward-path definition regardless.
+ conv_dims.oz *= conv_dims.iz;
+ ops *= conv_dims.oz;
+ }
- VLOG(1) << "Operations for Conv2DBackpropInput " << ops;
+ VLOG(1) << "Operations for " << op_features.op() << " " << ops;
 if (returned_conv_dims != nullptr) {
 *returned_conv_dims = conv_dims;
@@ -843,7 +878,11 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
 const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims,
 bool* found_unknown_shapes) const {
 int64 ops = 0;
- DCHECK_EQ(kConv2dBackpropFilter, op_features.op());
+
+ DCHECK(op_features.op() == kConv2dBackpropFilter ||
+ op_features.op() == kDepthwiseConv2dNativeBackpropFilter)
+ << "Invalid Operation: not kConv2dBackpropFilter nor "
+ "kDepthwiseConv2dNativeBackpropFilter";
 TensorShapeProto filter_shape;
 bool shape_found = false;
@@ -875,10 +914,15 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
 ops = conv_dims.batch;
 ops *= conv_dims.ox * conv_dims.oy;
 ops *= conv_dims.kx * conv_dims.ky;
- ops *= conv_dims.iz * conv_dims.oz;
- ops *= kOpsPerMac;
+ if (op_features.op() == kConv2dBackpropFilter) {
+ ops *= conv_dims.iz * conv_dims.oz;
+ } else {
+ // conv_dims always uses the forward-path definition regardless.
+ conv_dims.oz *= conv_dims.iz;
+ ops *= conv_dims.oz;
+ }
- VLOG(1) << "Operations for Conv2DBackpropFilter" << ops;
+ VLOG(1) << "Operations for " << op_features.op() << " " << ops;
 if (returned_conv_dims != nullptr) {
 *returned_conv_dims = conv_dims;
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 13ea43bed6..b2c021b73a 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -128,6 +128,23 @@ OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
 return op_context;
 }
+// DescribeDepthwiseConv2dNative constructs an OpContext for a
+// DepthwiseConv2dNative op applied to an input
+// tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
+// (kx, ky, iz2, cm), where
cm is channel multiplier + +OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1, + int iz2, int kx, int ky, int cm) { + OpContext op_context; + SetCpuDevice(&op_context.op_info); + op_context.op_info.set_op("DepthwiseConv2dNative"); + + DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs()); + DescribeTensor4D(kx, ky, iz2, cm, op_context.op_info.add_inputs()); + + return op_context; +} + // DescribeFusedConv2DBiasActivation constructs an OpContext for a // FusedConv2DBiasActivation applied to a convolution input tensor with shape // (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a @@ -505,6 +522,15 @@ TEST_F(OpLevelCostEstimatorTest, Conv2DExecutionTime) { EXPECT_FALSE(cost.inaccurate); } +TEST_F(OpLevelCostEstimatorTest, DepthwiseConv2dNativeExecutionTime) { + auto cost = + PredictCosts(DescribeDepthwiseConv2dNative(16, 19, 19, 48, 48, 5, 5, 3)); + EXPECT_EQ(Costs::Duration(112340), cost.memory_time); + EXPECT_EQ(Costs::Duration(4158720), cost.compute_time); + EXPECT_EQ(Costs::Duration(4271060), cost.execution_time); + EXPECT_FALSE(cost.inaccurate); +} + TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) { auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1)); EXPECT_EQ(Costs::Duration(2000), cost.memory_time); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 5b5e1e024e..900dfa95c5 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -604,6 +604,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 3f9feac55f..1f6f563687 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -65,7 +65,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool remove_redundant_bitcast = true; bool remove_redundant_cast = true; bool remove_negation = true; - bool hoist_cwise_unary_chains = true; + bool hoist_cwise_unary_chains = false; bool convert_sqrt_div_to_rsqrt_mul = false; bool remove_idempotent = true; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index e109e66633..067adb359c 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -696,6 +696,9 @@ TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) { item.fetch = {"id"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); + ArithmeticOptimizer optimizer; EnableOnlyHoistCommonFactor(&optimizer); @@ -734,6 +737,13 @@ TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) { EXPECT_EQ("id", id_node->name()); EXPECT_EQ(HoistDivName("add"), id_node->input(0)); } + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + if (use_ints) { + test::ExpectTensorEqual<int32>(tensors_expected[0], tensors[0]); + } else { + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); + } } } } @@ -1156,6 +1166,11 @@ TEST_F(ArithmeticOptimizerTest, 
RemoveIdentityTransposesMultipleOutputs) { item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28})); + item.feed = {{"inputs", x_t}}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveIdentityTranspose(&optimizer); @@ -1168,6 +1183,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) { EXPECT_EQ(node.input(2), "Split:2"); } } + + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) { @@ -1184,6 +1203,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) { item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3})); + item.feed = {{"Placeholder", x_t}}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveIdentityTranspose(&optimizer); @@ -1194,6 +1218,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) { EXPECT_EQ(2, outputs_node->input_size()); EXPECT_EQ(outputs_node->input(0), "outputs_const"); EXPECT_EQ(outputs_node->input(1), "^Placeholder"); + + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) { @@ -1440,6 +1468,11 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) { item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto x_t = GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3})); + item.feed = {{"inputs", x_t}}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveRedundantBitcast(&optimizer); @@ -1451,6 +1484,10 @@ TEST_F(ArithmeticOptimizerTest, CombineBitcasts) { EXPECT_EQ(3, output.node_size()); EXPECT_EQ(1, CountOpNodes(output, "Bitcast")); EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2")); + + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]); } TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) { @@ -1465,6 +1502,11 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) { item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3})); + item.feed = {{"inputs", x_t}}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveRedundantBitcast(&optimizer); @@ -1476,6 +1518,10 @@ TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) { EXPECT_EQ(2, output.node_size()); EXPECT_EQ(0, CountOpNodes(output, "Bitcast")); EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs")); + + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]); } TEST_F(ArithmeticOptimizerTest, 
RemoveRedundantCast) { @@ -1489,6 +1535,11 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) { item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto x_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3})); + item.feed = {{"inputs", x_t}}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveRedundantCast(&optimizer); @@ -1500,6 +1551,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) { EXPECT_EQ(2, output.node_size()); EXPECT_EQ(0, CountOpNodes(output, "Cast")); EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs")); + + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorEqual<int8>(tensors_expected[0], tensors[0]); } TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 47d8827686..e6a74dbdcd 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -2370,115 +2370,124 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, } } - // Partial constant folding for Concat which is not commutative, so - // we have to preserve order and can only push consecutive runs of constant - // inputs into sub-nodes. - if (IsConcat(*node) && num_non_control_inputs > 3 && - node->name().rfind("_partial_split_") == string::npos) { - int axis_arg = -1; - int begin = 0; - int end = num_non_control_inputs; - if (node->op() == "Concat") { - begin = 1; - axis_arg = 0; - } else if (node->op() == "ConcatV2") { - end = num_non_control_inputs - 1; - axis_arg = num_non_control_inputs - 1; - } else { - continue; - } + if (PartialConcatConstFolding(optimized_graph, properties, node)) { + graph_modified_ = true; + continue; + } + } - const NodeDef* axis_arg_node = - node_map_->GetNode(NodeName(node->input(axis_arg))); - if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) { - // We cannot constant fold Concat unless we the axis argument is - // constant. Skip node. - continue; - } + return Status::OK(); +} - // We search for consecutive runs of constant inputs in the range - // [begin:end[ and push then down into child nodes. - std::vector<std::pair<int, int>> constant_input_runs; - int first = begin; - int last = begin; - while (last < end) { - while (first < end && !IsReallyConstant(*node_map_->GetNode( - NodeName(node->input(first))))) { - ++first; - } - // Invariant: node[first] is constant || first >= end. - last = first + 1; - while (last < end && IsReallyConstant(*node_map_->GetNode( - NodeName(node->input(last))))) { - ++last; - } - // Invariant: node[last] is not constant || last >= end - // Discard intervals shorter than 2 elements. - if (first < end && (last - first) > 1) { - constant_input_runs.emplace_back(first, last); - } - first = last; +bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph, + GraphProperties* properties, + NodeDef* node) { + // Partial constant folding for Concat which is not commutative, so + // we have to preserve order and can only push consecutive runs of constant + // inputs into sub-nodes. 
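Concretely, the rewrite that the relocated PartialConcatConstFolding below performs, as a sketch (node names invented for illustration):

  // Before: one ConcatV2 with a consecutive run of constant inputs.
  //   concat = ConcatV2(x, c1, c2, y, axis)      // "N" = 4
  // After (the child's name follows the "_partial_split_" convention):
  //   child  = ConcatV2(c1, c2, axis)            // all-const, "N" = 2
  //   concat = ConcatV2(x, child, y, axis)       // "N" = 3
  // The child has only constant inputs, so ordinary constant folding can
  // collapse it into a single Const node on a later pass.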
+ const int num_non_control_inputs = NumNonControlInputs(*node);
+ if (IsConcat(*node) && num_non_control_inputs > 3 &&
+ node->name().rfind("_partial_split_") == string::npos) {
+ int axis_arg = -1;
+ int begin = 0;
+ int end = num_non_control_inputs;
+ if (node->op() == "Concat") {
+ begin = 1;
+ axis_arg = 0;
+ } else if (node->op() == "ConcatV2") {
+ end = num_non_control_inputs - 1;
+ axis_arg = num_non_control_inputs - 1;
+ } else {
+ return false;
+ }
+
+ const NodeDef* axis_arg_node =
+ node_map_->GetNode(NodeName(node->input(axis_arg)));
+ if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
+ // We cannot constant fold Concat unless the axis argument is
+ // constant. Skip node.
+ return false;
+ }
+
+ // We search for consecutive runs of constant inputs in the range
+ // [begin:end) and push them down into child nodes.
+ std::vector<std::pair<int, int>> constant_input_runs;
+ int first = begin;
+ int last = begin;
+ while (last < end) {
+ while (first < end && !IsReallyConstant(*node_map_->GetNode(
+ NodeName(node->input(first))))) {
+ ++first;
+ }
+ // Invariant: node[first] is constant || first >= end.
+ last = first + 1;
+ while (last < end && IsReallyConstant(*node_map_->GetNode(
+ NodeName(node->input(last))))) {
+ ++last;
+ }
+ // Invariant: node[last] is not constant || last >= end
+ // Discard intervals shorter than 2 elements.
+ if (first < end && (last - first) > 1) {
+ constant_input_runs.emplace_back(first, last);
+ }
+ first = last;
+ }
- // Skip if all inputs are constant, and let constant folding take over.
- if (constant_input_runs.size() == 1 &&
- constant_input_runs[0].first == begin &&
- constant_input_runs[0].second == end) {
- continue;
+ // Skip if all inputs are constant, and let constant folding take over.
+ if (constant_input_runs.size() == 1 &&
+ constant_input_runs[0].first == begin &&
+ constant_input_runs[0].second == end) {
+ return false;
+ }
+ std::set<int> inputs_to_delete;
+ for (auto interval : constant_input_runs) {
+ // Push the constant inputs in the interval to a child node that can be
+ // constant folded.
+ const string new_node_name = OptimizedNodeName(
+ *node, strings::StrCat("_partial_split_", interval.first));
+ if (node_map_->NodeExists(new_node_name)) {
+ break;
 }
- std::set<int> inputs_to_delete;
- for (auto interval : constant_input_runs) {
- // Push the constant inputs in the interval to a child node than can be
- // constant folded.
- const string new_node_name = OptimizedNodeName(
- *node, strings::StrCat("_partial_split_", interval.first));
- if (node_map_->NodeExists(new_node_name)) {
- break;
- }
- NodeDef* added_node = optimized_graph->add_node();
- *added_node = *node;
- added_node->set_name(new_node_name);
- node_map_->AddNode(added_node->name(), added_node);
- added_node->clear_input();
- for (int i = interval.first; i < interval.second; ++i) {
- added_node->add_input(node->input(i));
- node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
- added_node->name());
- if (i != interval.first) {
- inputs_to_delete.insert(i);
- }
+ NodeDef* added_node = optimized_graph->add_node();
+ *added_node = *node;
+ added_node->set_name(new_node_name);
+ node_map_->AddNode(added_node->name(), added_node);
+ added_node->clear_input();
+ for (int i = interval.first; i < interval.second; ++i) {
+ added_node->add_input(node->input(i));
+ node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
+ added_node->name());
+ if (i != interval.first) {
+ inputs_to_delete.insert(i);
 }
 }
- added_node->add_input(node->input(axis_arg));
- (*added_node->mutable_attr())["N"].set_i(interval.second -
- interval.first);
- node_map_->AddOutput(NodeName(node->input(axis_arg)),
- added_node->name());
-
- // Overwrite the first constant input with the result of the added
- // child node.
- node->set_input(interval.first, added_node->name());
- node_map_->AddOutput(added_node->name(), node->name());
 }
- if (!constant_input_runs.empty()) {
- graph_modified_ = true;
- if (!inputs_to_delete.empty()) {
- // Fix up the inputs to the original node.
- std::vector<string> tmp(node->input().begin(), node->input().end());
- node->clear_input();
- for (int i = 0; i < tmp.size(); ++i) {
- if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
- node->add_input(tmp[i]);
- }
+ added_node->add_input(node->input(axis_arg));
+ (*added_node->mutable_attr())["N"].set_i(interval.second -
+ interval.first);
+ node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
+
+ // Overwrite the first constant input with the result of the added
+ // child node.
+ node->set_input(interval.first, added_node->name());
+ node_map_->AddOutput(added_node->name(), node->name());
+ }
+ if (!constant_input_runs.empty()) {
+ if (!inputs_to_delete.empty()) {
+ // Fix up the inputs to the original node.
+ std::vector<string> tmp(node->input().begin(), node->input().end());
+ node->clear_input();
+ for (int i = 0; i < tmp.size(); ++i) {
+ if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
+ node->add_input(tmp[i]);
 }
- (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
- properties->ClearInputProperties(node->name());
 }
- continue;
+ (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
+ properties->ClearInputProperties(node->name());
 }
+ return true;
 }
 }
-
- return Status::OK();
+ return false;
 }
 Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index a694f1721a..2096576538 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -101,6 +101,11 @@ class ConstantFolding : public GraphOptimizer {
 Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
 GraphDef* output);
+ // Applies partial constant folding for Concat, which is not commutative.
+ // Returns true if the transformation was applied successfully.
+ bool PartialConcatConstFolding(GraphDef* optimized_graph, + GraphProperties* properties, NodeDef* node); + // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index 1bec9086f7..a44e1ee7f9 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -14,10 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/function_optimizer.h" + #include <unordered_map> + #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph_def_util.h" @@ -74,6 +77,73 @@ string UniqueSpecializedFunctionName(const FunctionDef& func, return unique_name; } +// Specialized function instantiation type parameters, body parameters, and +// const inputs. +struct FunctionSpecializationSignature { + string func_name; + std::unordered_map<string, DataType> type_parameters; + std::unordered_map<string, AttrValue> body_parameters; + std::unordered_map<int, string> const_inputs; + + bool operator==(const FunctionSpecializationSignature& other) const { + bool equals = func_name == other.func_name && + type_parameters == other.type_parameters && + const_inputs == other.const_inputs; + + if (!equals) return false; + + // Equality is not defined for AttrValue. + if (body_parameters.size() != other.body_parameters.size()) return false; + + for (const auto& lhs : body_parameters) { + auto it = other.body_parameters.find(lhs.first); + if (it == other.body_parameters.end()) return false; + if (!AreAttrValuesEqual(lhs.second, (*it).second)) return false; + } + + return true; + } + + struct Hash { + uint64 operator()(FunctionSpecializationSignature const& s) const { + uint64 h = Hash64(s.func_name); + + // Use std::map for deterministic iteration order. 
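The std::map copies below are what make the signature hash order-independent: iterating the unordered_map members directly would make the result depend on bucket layout, so two equal signatures could hash differently. A standalone sketch of the same idiom (Hash64 and Hash64Combine come from tensorflow/core/lib/hash/hash.h; the seed is arbitrary):

  #include <map>
  #include <string>
  #include <unordered_map>
  #include "tensorflow/core/lib/hash/hash.h"

  tensorflow::uint64 OrderIndependentHash(
      const std::unordered_map<std::string, int>& m) {
    // Copy into std::map so iteration follows sorted key order.
    std::map<std::string, int> ordered(m.begin(), m.end());
    tensorflow::uint64 h = 42;  // arbitrary fixed seed
    for (const auto& kv : ordered) {
      h = tensorflow::Hash64Combine(tensorflow::Hash64(kv.first), h);
      h = tensorflow::Hash64Combine(
          static_cast<tensorflow::uint64>(kv.second), h);
    }
    return h;
  }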
+ + std::map<string, DataType> types(s.type_parameters.begin(), + s.type_parameters.end()); + for (const auto& pair : types) { + AttrValue attr_value; + attr_value.set_type(pair.second); + h = Hash64Combine(Hash64(pair.first), h); + h = Hash64Combine(AttrValueHash(attr_value), h); + } + + std::map<string, AttrValue> body(s.body_parameters.begin(), + s.body_parameters.end()); + for (const auto& pair : body) { + h = Hash64Combine(Hash64(pair.first), h); + h = Hash64Combine(AttrValueHash(pair.second), h); + } + + std::map<int, string> inputs(s.const_inputs.begin(), + s.const_inputs.end()); + for (const auto& pair : inputs) { + h = Hash64Combine(std::hash<int>()(pair.first), h); + h = Hash64Combine(Hash64(pair.second), h); + } + + return h; + } + }; +}; + +struct FunctionSpecialization { + string specialized_func_name; + std::unordered_set<string> const_inputs; + std::unordered_set<string> control_deps; +}; + class FunctionOptimizerContext { public: explicit FunctionOptimizerContext(RewriterConfig::Toggle opt_level, @@ -108,6 +178,16 @@ class FunctionOptimizerContext { return gtl::FindWithDefault(inlined_functions_, name, nullptr); } + const FunctionSpecialization* FindFunctionSpecialization( + const FunctionSpecializationSignature& sig) const { + return gtl::FindOrNull(specialized_functions_, sig); + } + + void AddSpecializedFunction(const FunctionSpecializationSignature& sig, + const FunctionSpecialization& specialized_func) { + specialized_functions_.emplace(sig, specialized_func); + } + private: void InitializeTrulyConstNodes(const GrapplerItem& item) { std::unordered_set<string> feed_nodes; @@ -148,6 +228,12 @@ class FunctionOptimizerContext { // Nodes that are Const and not in feed. std::unordered_map<string, const NodeDef*> truly_const_nodes_; + // Specialized functions. 
+ std::unordered_map<FunctionSpecializationSignature, + const FunctionSpecialization, + FunctionSpecializationSignature::Hash> + specialized_functions_; + TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext); }; @@ -303,14 +389,34 @@ void RemovePushedDownConstInputs(const std::unordered_set<string>& const_inputs, for (const string& ctrl : control_deps) { if (existing_control_deps.find(ctrl) == existing_control_deps.end()) { - VLOG(3) << "Forward control dependency to function caller node: input=" - << ctrl; + VLOG(3) << "Forward control dependency: input=" << ctrl; specialized_func_node->add_input(ctrl); } } } } +Status InitializeFunctionSpecializationSignature( + const NodeDef& func_node, const FunctionDef& func, + const AttrValueMap& func_attr, const FunctionOptimizerContext& ctx, + FunctionSpecializationSignature* sig) { + sig->func_name = func.signature().name(); + + TF_RETURN_IF_ERROR( + InstantiationTypeParameters(func, func_attr, &sig->type_parameters)); + TF_RETURN_IF_ERROR( + InstantiationBodyParameters(func, func_attr, &sig->body_parameters)); + + for (int i = 0; i < func_node.input_size(); ++i) { + const string& input = func_node.input(i); + if (ctx.IsTrulyConst(input)) { + sig->const_inputs.emplace(i, input); + } + } + + return Status::OK(); +} + Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func, FunctionOptimizerContext* ctx, GraphDef* optimized_graph) { @@ -320,6 +426,32 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func, const std::unordered_map<string, AttrValue> func_attr( func_node.attr().begin(), func_node.attr().end()); + FunctionSpecializationSignature signature; + TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature( + func_node, func, func_attr, *ctx, &signature)); + + // Check if function was already specialized for identical context. + const FunctionSpecialization* already_specialized = + ctx->FindFunctionSpecialization(signature); + + if (already_specialized) { + VLOG(2) << "Function was already specialized in identical context: " + "specialized_name=" + << already_specialized->specialized_func_name; + + // Add a function call node for the specialized function. + NodeDef* specialized_func_node = optimized_graph->add_node(); + *specialized_func_node = func_node; + specialized_func_node->set_op(already_specialized->specialized_func_name); + + RemovePushedDownConstInputs(already_specialized->const_inputs, + already_specialized->control_deps, + specialized_func_node); + + return Status::OK(); + } + + // Add a new specialized function definition to the library. const auto& flib = ctx->function_library(); // Make a GrapplerFunctionItem and convert it back to FunctionDef after @@ -358,6 +490,10 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func, // Update specialized node to remove inputs for pushed down consts. 
RemovePushedDownConstInputs(const_inputs, control_deps, specialized_func_node); + + ctx->AddSpecializedFunction( + signature, {specialized_func_name, const_inputs, control_deps}); + return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc index 147a264421..0aaf57e947 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc @@ -718,5 +718,147 @@ TEST_F(FunctionOptimizerTest, SpecializeFunction_PushDownConstInput) { test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]); } +TEST_F(FunctionOptimizerTest, SpecializeFunction_OncePerUniqueContext) { + using test::function::NDef; + + FunctionOptimizer optimizer(RewriterConfig::DEFAULT); + + // Mark MyMul as noinline. + FunctionDef mul_func = FunctionDefHelper::Create( + "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, int32}"}, + {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "output:z:0"}}); + (*mul_func.mutable_attr())["_noinline"].set_b(true); + std::vector<FunctionDef> function_library = {mul_func}; + + const Tensor kTwo = test::AsScalar<float>(2.0); + const Tensor kThree = test::AsScalar<float>(3.0); + + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("init", "NoOp", {}, {}, kDevice), + + // Float placeholders. + NDef("xf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), + NDef("yf", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), + + // Int32 placeholders. + NDef("xi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice), + NDef("yi", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice), + + // Consts. Control inputs has to be attached to specialized func calls. + NDef("two", "Const", {"^init", "^xf"}, + {{"dtype", DT_FLOAT}, {"value", kTwo}}, kDevice), + NDef("three", "Const", {"^init", "^xf"}, + {{"dtype", DT_FLOAT}, {"value", kThree}}, kDevice), + + // Specialization #1: DT_FLOAT type parameter. + NDef("mul_1", "MyMul", {"xf", "yf"}, {{"T", DT_FLOAT}}, kDevice), + NDef("mul_2", "MyMul", {"yf", "xf"}, {{"T", DT_FLOAT}}, kDevice), + + // Specialization #2: DT_INT32 type parameter. + NDef("mul_3", "MyMul", {"xi", "yi"}, {{"T", DT_INT32}}, kDevice), + + // Specialization #3: DT_FLOAT type parameter + const input kTwo. + NDef("mul_4", "MyMul", {"xf", "two"}, {{"T", DT_FLOAT}}, kDevice), + NDef("mul_5", "MyMul", {"yf", "two"}, {{"T", DT_FLOAT}}, kDevice), + + // Specialization #4: DT_FLOAT type parameter + const input kThree. + NDef("mul_6", "MyMul", {"three", "xf"}, {{"T", DT_FLOAT}}, kDevice)}, + function_library); + + GraphDef output; + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + // Make sure that MyMul was specialized once per unique context. + EXPECT_EQ(4, output.library().function_size()); + + // And graph nodes calling specialized functions. 
+ int count = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "mul_1" && count++) { + EXPECT_EQ("MyMul_specialized_for_mul_1", node.op()); + ASSERT_EQ(2, node.input_size()); + EXPECT_EQ("xf", node.input(0)); + EXPECT_EQ("yf", node.input(1)); + + } else if (node.name() == "mul_2" && count++) { + EXPECT_EQ("MyMul_specialized_for_mul_1", node.op()); + ASSERT_EQ(2, node.input_size()); + EXPECT_EQ("yf", node.input(0)); + EXPECT_EQ("xf", node.input(1)); + + } else if (node.name() == "mul_3" && count++) { + EXPECT_EQ("MyMul_specialized_for_mul_3", node.op()); + ASSERT_EQ(2, node.input_size()); + EXPECT_EQ("xi", node.input(0)); + EXPECT_EQ("yi", node.input(1)); + + } else if (node.name() == "mul_4" && count++) { + EXPECT_EQ("MyMul_specialized_for_mul_4", node.op()); + ASSERT_EQ(2, node.input_size()); + EXPECT_EQ("xf", node.input(0)); + EXPECT_EQ("^init", node.input(1)); + + } else if (node.name() == "mul_5" && count++) { + EXPECT_EQ("MyMul_specialized_for_mul_4", node.op()); + ASSERT_EQ(3, node.input_size()); + EXPECT_EQ("yf", node.input(0)); + EXPECT_EQ("^init", node.input(1)); + EXPECT_EQ("^xf", node.input(2)); + + } else if (node.name() == "mul_6" && count++) { + EXPECT_EQ("MyMul_specialized_for_mul_6", node.op()); + ASSERT_EQ(2, node.input_size()); + EXPECT_EQ("xf", node.input(0)); + EXPECT_EQ("^init", node.input(1)); + } + } + EXPECT_EQ(6, count); + + // And that graph evaluation yields the same result. + Tensor pi = test::AsScalar<float>(3.14f); + Tensor four = test::AsScalar<int32>(4); + item.fetch = {"mul_1", "mul_2", "mul_3", "mul_4", "mul_5", "mul_6"}; + item.feed = {{"xf", pi}, {"yf", pi}, {"xi", four}, {"yi", four}}; + + auto tensors_expected = EvaluateFetchNodes(item); + GrapplerItem optimized(item, std::move(output)); + auto tensors = EvaluateFetchNodes(optimized); + + test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]); + test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]); + test::ExpectTensorEqual<int32>(tensors_expected[2], tensors[2]); + test::ExpectTensorEqual<float>(tensors_expected[3], tensors[3]); + test::ExpectTensorEqual<float>(tensors_expected[4], tensors[4]); + test::ExpectTensorEqual<float>(tensors_expected[5], tensors[5]); +} + +TEST_F(FunctionOptimizerTest, PruningUselessLibraryFunctions) { + using test::function::NDef; + FunctionOptimizer optimizer(RewriterConfig::DEFAULT); + DisableFunctionSpecialization(&optimizer); + auto func = test::function::XTimesTwo(); + (*func.mutable_attr())["_noinline"].set_b(true); + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, "/device:CPU:0"), + NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, "/device:CPU:0"), + NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, "/device:CPU:0")}, + // FunctionLib + { + func, + test::function::XTimesTwoInt32(), + test::function::XTimes16(), + }); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(output.library().function().size(), 1); + EXPECT_EQ(output.library().function(0).signature().name(), "XTimesTwo"); +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index 5adc5b9227..7d3520febc 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" @@ -504,6 +505,140 @@ Status RemoveStackOps(const std::unordered_set<string>& nodes_to_preserve, return Status::OK(); } +Status RemoveDeadBranches(const std::unordered_set<string>& nodes_to_preserve, + GraphDef* optimized_graph) { + std::unordered_set<const NodeDef*> dead_nodes; + std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs; + // TODO(bsteiner): also rewrite switches as identity. For now we just record + // them + std::unordered_set<GraphView::OutputPort, GraphView::HashPort> + identity_switches; + + GraphView view(optimized_graph); + for (const NodeDef& node : optimized_graph->node()) { + if (!IsSwitch(node)) { + continue; + } + if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) { + continue; + } + GraphView::InputPort ctrl_port(&node, 1); + GraphView::OutputPort ctrl_node = view.GetRegularFanin(ctrl_port); + if (!IsConstant(*ctrl_node.node)) { + continue; + } + Tensor selector; + CHECK(selector.FromProto(ctrl_node.node->attr().at("value").tensor())); + const int dead_fanout = selector.scalar<bool>()() ? 0 : 1; + GraphView::OutputPort dead(const_cast<NodeDef*>(&node), dead_fanout); + identity_switches.insert(dead); + + SetVector<GraphView::InputPort, GraphView::HashPort> zombie_inputs; + for (const GraphView::InputPort& port : view.GetFanout(dead)) { + if (dead_nodes.find(port.node) == dead_nodes.end()) { + zombie_inputs.PushBack(port); + } + } + // If we encounter a single node that must be preserved in the fanout of the + // switch node we need to preserve the entire switch fanout: we therefore + // work on a local copy that only gets committed to the master copy once the + // whole fanout has been explored. + std::unordered_set<const NodeDef*> local_dead_nodes = dead_nodes; + std::unordered_map<NodeDef*, std::set<int>> local_dead_merge_inputs = + dead_merge_inputs; + bool found_node_to_preserve = false; + while (!found_node_to_preserve && !zombie_inputs.Empty()) { + GraphView::InputPort dead = zombie_inputs.PopBack(); + if (nodes_to_preserve.find(dead.node->name()) != + nodes_to_preserve.end()) { + found_node_to_preserve = true; + break; + } + + if (local_dead_nodes.find(dead.node) != local_dead_nodes.end()) { + continue; + } + + if (IsMerge(*dead.node)) { + const int fanout = dead.node->attr().at("N").i(); + if (fanout > 2) { + // This never happens in practice, so we'll just skip these to + // simplify the code for now. + found_node_to_preserve = true; + break; + } + GraphView::OutputPort value_index(dead.node, 1); + const std::unordered_set<GraphView::InputPort, GraphView::HashPort>& + index_fanout = view.GetFanout(value_index); + if (!index_fanout.empty()) { + // The 2nd output (that indicates which input is propagated) is + // connected. This never happens in practice, so we'll just skip this + // case to simplify the code for now. + found_node_to_preserve = true; + break; + } + + bool fully_dead = false; + if (dead.port_id < 0) { + // If the control dependency never gets triggered the merge will also + // never get triggered. 
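For intuition, the Switch/Merge semantics this pass exploits, sketched with illustrative names that mirror the new unit test below:

  // pred = Const(false)  =>  Switch(v, pred) only ever emits on port 0
  // (output_false); port 1 (output_true) is dead.
  //   a = Square(switch:0);  // alive
  //   b = Sqrt(switch:1);    // dead: reachable only through the dead port
  //   m = Merge(a, b);       // still fires on its live input, so it is
  //                          // rewritten to Identity(a) and b is pruned.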
+ local_dead_nodes.insert(dead.node); + fully_dead = true; + } else { + local_dead_merge_inputs[dead.node].insert(dead.port_id); + if (local_dead_merge_inputs[dead.node].size() == + dead.node->attr().at("N").i()) { + fully_dead = true; + } + if (fully_dead) { + local_dead_nodes.insert(dead.node); + for (const GraphView::InputPort& port : + view.GetFanouts(*dead.node, true)) { + zombie_inputs.PushBack(port); + } + } + } + } else { + if (local_dead_nodes.insert(dead.node).second) { + for (const GraphView::InputPort& dead_fanout : + view.GetFanouts(*dead.node, true)) { + zombie_inputs.PushBack(dead_fanout); + } + } + } + } + if (!found_node_to_preserve) { + std::swap(dead_nodes, local_dead_nodes); + std::swap(dead_merge_inputs, local_dead_merge_inputs); + } + } + + int last = optimized_graph->node_size() - 1; + for (int i = optimized_graph->node_size() - 1; i >= 0; --i) { + NodeDef* node = optimized_graph->mutable_node(i); + if (dead_nodes.find(node) != dead_nodes.end()) { + optimized_graph->mutable_node()->SwapElements(i, last); + last--; + } + } + optimized_graph->mutable_node()->DeleteSubrange(last + 1, dead_nodes.size()); + + for (const auto& itr : dead_merge_inputs) { + NodeDef* dead_node = itr.first; + if (dead_nodes.find(dead_node) != dead_nodes.end()) { + // The node has been pruned since all its inputs are dead. + continue; + } + const std::set<int>& dead_inputs = itr.second; + for (int index : dead_inputs) { + dead_node->mutable_input()->DeleteSubrange(index, 1); + } + dead_node->set_op("Identity"); + dead_node->mutable_attr()->erase("N"); + } + return Status::OK(); +} + } // namespace Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, @@ -517,6 +652,11 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (options_.enable_stack_push_removal) { TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph)); } + if (opt_level_ == RewriterConfig::AGGRESSIVE && + options_.enable_dead_branch_removal) { + TF_RETURN_IF_ERROR( + RemoveDeadBranches(item.NodesToPreserve(), optimized_graph)); + } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h index 764506f7c1..85b8e65543 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.h +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h @@ -54,6 +54,7 @@ class LoopOptimizer : public GraphOptimizer { struct LoopOptimizerOptions { bool enable_loop_invariant_node_motion = false; bool enable_stack_push_removal = true; + bool enable_dead_branch_removal = true; static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) { LoopOptimizerOptions options; diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc index 10ec544424..6fd177b710 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc @@ -589,5 +589,112 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) { } } +TEST_F(LoopOptimizerTest, RemoveDeadBranches) { + Scope scope = Scope::NewRootScope(); + Output v_in = ops::Variable(scope.WithOpName("v_in"), {3}, DT_FLOAT); + + Output ctrl1 = ops::Const(scope.WithOpName("ctrl1"), false, TensorShape({})); + ops::Switch s1(scope.WithOpName("switch1"), v_in, ctrl1); + Output square1 = ops::Square(scope.WithOpName("square1"), s1.output_false); + Output sqrt1 = ops::Sqrt(scope.WithOpName("sqrt1"), s1.output_true); + + 
+  Output ctrl2 = ops::Const(scope.WithOpName("ctrl2"), true, TensorShape({}));
+  ops::Switch s2(scope.WithOpName("switch2"), v_in, ctrl2);
+  Output square2 = ops::Square(scope.WithOpName("square2"), s2.output_false);
+  Output sqrt2 = ops::Sqrt(scope.WithOpName("sqrt2"), s2.output_true);
+
+  Output ctrl3 = ops::Const(scope.WithOpName("ctrl3"), false, TensorShape({}));
+  ops::Switch s3(scope.WithOpName("switch3"), v_in, ctrl3);
+  Output square3 = ops::Square(scope.WithOpName("square3"), s3.output_false);
+  Output sqrt3 = ops::Sqrt(scope.WithOpName("sqrt3"), s3.output_true);
+
+  Output ctrl4 = ops::Const(scope.WithOpName("ctrl4"), false, TensorShape({}));
+  ops::Switch s4(scope.WithOpName("switch4"), v_in, ctrl4);
+  Output square4 = ops::Square(scope.WithOpName("square4"), s4.output_false);
+  Output sqrt4 = ops::Sqrt(scope.WithOpName("sqrt4"), s4.output_true);
+
+  ops::Merge m1(scope.WithOpName("m1"), {square1, sqrt1});
+  ops::Merge m2(scope.WithOpName("m2"), {v_in, square1});
+  ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1});
+  ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2});
+  ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1});
+  ops::Merge m6(scope.WithOpName("m6").WithControlDependencies(sqrt2),
+                {v_in, square1});
+  ops::Merge m7(scope.WithOpName("m7").WithControlDependencies(sqrt1),
+                {v_in, square1});
+
+  ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1);
+  Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false);
+  Output id2 = ops::Identity(scope.WithOpName("id2"), s5.output_true);
+  ops::Merge m8(scope.WithOpName("m8"), {id1, id2});
+
+  ops::Switch s6(scope.WithOpName("switch6"), v_in, ctrl1);
+  Output id3 = ops::Identity(scope.WithOpName("id3"), s6.output_false);
+  Output id4 = ops::Identity(scope.WithOpName("id4"), s6.output_true);
+  ops::Merge m9(scope.WithOpName("m9"), {id3, id4});
+
+  GrapplerItem item;
+  item.fetch.push_back("m8");
+  item.fetch.push_back("id4");
+
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_CHECK_OK(status);
+
+  for (const NodeDef& node : output.node()) {
+    // These nodes should have been pruned: sqrt1 and square2 sit on dead
+    // Switch outputs, m5 merges only dead tensors, and m7 has a control
+    // dependency on the dead sqrt1. (The original assertions checked the
+    // capitalized names "Square1" and "Sqrt2", which never occur in the
+    // graph, so they passed vacuously; square1 and sqrt2 are in fact alive.)
+    EXPECT_NE("sqrt1", node.name());
+    EXPECT_NE("square2", node.name());
+    EXPECT_NE("m5", node.name());
+    EXPECT_NE("m7", node.name());
+
+    if (node.name() == "m1") {
+      // sqrt1 is dead
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square1", node.input(0));
+    } else if (node.name() == "m2") {
+      // both inputs are alive
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("v_in", node.input(0));
+      EXPECT_EQ("square1", node.input(1));
+    } else if (node.name() == "m3") {
+      // sqrt1 is dead
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("v_in", node.input(0));
+    } else if (node.name() == "m4") {
+      // both inputs are alive
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("square1", node.input(0));
+      EXPECT_EQ("sqrt2", node.input(1));
+    } else if (node.name() == "m6") {
+      // both inputs are alive and the control dependency can get triggered
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ("v_in", node.input(0));
+      EXPECT_EQ("square1", node.input(1));
+      EXPECT_EQ("^sqrt2", node.input(2));
+    } else if (node.name() == "m8") {
+      // The node is to be preserved because of a fetch
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(2,
node.input_size()); + EXPECT_EQ("id1", node.input(0)); + EXPECT_EQ("id2", node.input(1)); + } else if (node.name() == "m9") { + // The node is to be preserved because of a fetch + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("id3", node.input(0)); + EXPECT_EQ("id4", node.input(1)); + } + } +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index 887a988af9..8247cce339 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -163,30 +163,28 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) { output.library()); // Specialized and optimized functions should be added to the graph. - EXPECT_EQ(6, optimized_flib.num_functions()); + EXPECT_EQ(5, optimized_flib.num_functions()); // MyQuadratic should be specialized once: // 0. 'quadratic' node in the main graph const string optimized_0 = "MyQuadratic_specialized_for_quadratic"; // MySquare should be specialized and optimized for 3 instantiations: - // 1. 'square' node in the main graph - // 2. 'square' node in the MyQuadratic specialization - // 3. 'quadratic' node in the MyQuadratic specialization + // 1. 'square' node in the main graph + // 2. 'square' node in the MyQuadratic specialization + // 3*. 'quadratic' node in the MyQuadratic specialization + // has identical instantiation context to #2 const string optimized_1 = "MySquare_specialized_for_square"; const string optimized_2 = "MySquare_specialized_for_square_1"; - const string optimized_3 = "MySquare_specialized_for_quadratic"; const FunctionDef* optimized_func_0 = optimized_flib.Find(optimized_0); const FunctionDef* optimized_func_1 = optimized_flib.Find(optimized_1); const FunctionDef* optimized_func_2 = optimized_flib.Find(optimized_2); - const FunctionDef* optimized_func_3 = optimized_flib.Find(optimized_3); ASSERT_NE(optimized_func_0, nullptr); ASSERT_NE(optimized_func_1, nullptr); ASSERT_NE(optimized_func_2, nullptr); - ASSERT_NE(optimized_func_3, nullptr); // Graph should call optimized function. int count = 0; @@ -205,13 +203,14 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) { if (node.name() == "square" && count++) { EXPECT_EQ(optimized_2, node.op()); } else if (node.name() == "quadratic" && count++) { - EXPECT_EQ(optimized_3, node.op()); + // Share specialized function with the 'square' node. + EXPECT_EQ(optimized_2, node.op()); } } EXPECT_EQ(2, count); - const std::vector<const FunctionDef*> optimized_funcs = { - optimized_func_1, optimized_func_1, optimized_func_3}; + const std::vector<const FunctionDef*> optimized_funcs = {optimized_func_1, + optimized_func_2}; // MyMul should be inlined into all optimized versions of MySquare. for (const FunctionDef* optimized_func : optimized_funcs) { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index b87ae05546..1c6fef59ea 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -65,7 +65,7 @@ class NodeMap { // A vector with a set. The set stores the same elements as the vector, and // quickly answers whether a value is in the vector. Duplicated elements are not // allowed for now. -template <class T> +template <class T, class Hash = std::hash<T>> class SetVector { public: // Returns false if value already existed in the set, true otherwise. 
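To make the new template parameter concrete, here is a compilable miniature of SetVector as these hunks leave it. PushBack/PopBack behavior is inferred from the comment above and from call sites such as zombie_inputs.PushBack in the dead-branch pass, so treat it as a sketch rather than the actual grappler class:

```
// The default argument keeps existing SetVector<T> users compiling, while
// callers whose T has no std::hash specialization (e.g. a graph port type)
// can now supply their own hasher.
#include <functional>
#include <unordered_set>
#include <vector>

template <class T, class Hash = std::hash<T>>
class SetVectorSketch {
 public:
  // Returns false if value already existed in the set, true otherwise.
  bool PushBack(const T& value) {
    if (!set_.insert(value).second) return false;
    vector_.push_back(value);
    return true;
  }

  T PopBack() {
    T back = vector_.back();
    set_.erase(back);
    vector_.pop_back();
    return back;
  }

  bool Empty() const { return vector_.empty(); }

 private:
  std::unordered_set<T, Hash> set_;
  std::vector<T> vector_;
};

int main() {
  SetVectorSketch<int> sv;  // default std::hash<int> still works
  sv.PushBack(1);
  sv.PushBack(1);           // duplicate: rejected
  return sv.PopBack() == 1 ? 0 : 1;
}
```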
@@ -91,7 +91,7 @@ class SetVector { void Reserve(int64 size) { vector_.reserve(size); } private: - std::unordered_set<T> set_; + std::unordered_set<T, Hash> set_; std::vector<T> vector_; }; diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index 79b823fa2d..34603f9869 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -417,6 +417,63 @@ bool IsParametrized(const FunctionDef& func) { return HasParametrizedType(func) || HasParametrizedBody(func); } +Status InstantiationTypeParameters( + const FunctionDef& func, const AttrValueMap& func_instantiation_attr, + std::unordered_map<string, DataType>* type_parameters) { + if (!type_parameters->empty()) { + return errors::InvalidArgument("Type parameters output map must be empty"); + } + + GrapplerFunctionItemInstantiation instantiation(&func_instantiation_attr); + + const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) { + // Check if it's unknown and unresolved type. + if (arg.type() == DT_INVALID && + type_parameters->find(arg.type_attr()) == type_parameters->end()) { + DataType data_type; + TF_RETURN_IF_ERROR(instantiation.GetArgType(arg, &data_type)); + type_parameters->insert({arg.type_attr(), data_type}); + } + return Status::OK(); + }; + + for (const auto& input : func.signature().input_arg()) + TF_RETURN_IF_ERROR(resolve_type_attr(input)); + for (const auto& output : func.signature().output_arg()) + TF_RETURN_IF_ERROR(resolve_type_attr(output)); + + return Status::OK(); +} + +Status InstantiationBodyParameters( + const FunctionDef& func, const AttrValueMap& func_instantiation_attr, + std::unordered_map<string, AttrValue>* body_parameters) { + if (!body_parameters->empty()) { + return errors::InvalidArgument("Body parameters output map must be empty"); + } + + for (const NodeDef& func_body_node : func.node_def()) { + for (auto& attr : func_body_node.attr()) { + const string& placeholder = attr.second.placeholder(); + + if (placeholder.empty() || + body_parameters->find(placeholder) != body_parameters->end()) { + continue; + } + + auto it = func_instantiation_attr.find(placeholder); + if (it != func_instantiation_attr.end()) { + body_parameters->emplace(placeholder, it->second); + } else { + return errors::InvalidArgument("Can't resolve placeholder: ", + placeholder); + } + } + } + + return Status::OK(); +} + Status MakeGrapplerFunctionItem(const FunctionDef& func, const AttrValueMap& func_instantiation_attr, const FunctionLibraryDefinition& flib, diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index d9d71b80eb..4641bf5252 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -191,6 +191,19 @@ bool HasParametrizedBody(const FunctionDef& func); // Check if function has parametrized type or body. bool IsParametrized(const FunctionDef& func); +// Resolve function instantiation type parameters from the attributes of the +// caller node. Return error if type can't be resolved. +Status InstantiationTypeParameters( + const FunctionDef& func, const AttrValueMap& func_instantiation_attr, + std::unordered_map<string, DataType>* type_parameters); + +// Resolve function instantiation body parameters (values for the function body +// attr placeholders) from the attributes of the caller node. Return error if +// type can't be resolved. 
+Status InstantiationBodyParameters( + const FunctionDef& func, const AttrValueMap& func_instantiation_attr, + std::unordered_map<string, AttrValue>* body_parameters); + // Register GrapplerFunctionItem input arg expansion and function body outputs // in the GrapplerFunctionConnectivity. Use function library definition to // lookup function body nodes output names and ranges. @@ -205,10 +218,10 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position, // Make a GrapplerFunctionItem from the function definition and function // instantiation attributes (caller node attributes). Returns error if the given // function def cannot be converted (e.g. not all attributes are defined). -Status MakeGrapplerFunctionItem( - const FunctionDef& func, - const std::unordered_map<string, AttrValue>& func_instantiation_attr, - const FunctionLibraryDefinition& flib, GrapplerFunctionItem* item); +Status MakeGrapplerFunctionItem(const FunctionDef& func, + const AttrValueMap& func_instantiation_attr, + const FunctionLibraryDefinition& flib, + GrapplerFunctionItem* item); // Make a GrapplerFunction item from the function definition. Function must be // fully defined (no type or body parametrization). diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc index fa6fec70ff..15d8437438 100644 --- a/tensorflow/core/grappler/utils/functions_test.cc +++ b/tensorflow/core/grappler/utils/functions_test.cc @@ -54,6 +54,44 @@ TEST_F(FunctionsTest, IsParametrized) { EXPECT_FALSE(IsParametrized(non_parametrized_func)); } +TEST_F(FunctionsTest, InstantiationParameters) { + // Function definition is invalid, only type/body parameters are important. + FunctionDef func = FunctionDefHelper::Create( + "ParametrizedFunc", + /* inputs */ + {"input1:A", "input2:B", "input3:float"}, + /* outputs */ + {"output1: A", "output2:C"}, + /* type parameters */ + {"A: {float, double}", "B: {float, int32}", "C: {float, double}"}, + /* function body*/ + {{{"output"}, "FakeOp", {"input1", "input2"}, {{"key", "$key"}}}}, + /* Mapping between function returns and function node outputs. 
*/ + {{"x", "cx:output:0"}, {"y", "cy:output:0"}}); + + std::unordered_map<string, AttrValue> func_instantiation_attr; + func_instantiation_attr["key"].set_s("key-value"); + func_instantiation_attr["A"].set_type(DT_FLOAT); + func_instantiation_attr["B"].set_type(DT_INT32); + func_instantiation_attr["C"].set_type(DT_DOUBLE); + + std::unordered_map<string, DataType> type_parameters; + TF_EXPECT_OK(InstantiationTypeParameters(func, func_instantiation_attr, + &type_parameters)); + + ASSERT_EQ(3, type_parameters.size()); + EXPECT_EQ(DT_FLOAT, type_parameters["A"]); + EXPECT_EQ(DT_INT32, type_parameters["B"]); + EXPECT_EQ(DT_DOUBLE, type_parameters["C"]); + + std::unordered_map<string, AttrValue> body_parameters; + TF_EXPECT_OK(InstantiationBodyParameters(func, func_instantiation_attr, + &body_parameters)); + + ASSERT_EQ(1, body_parameters.size()); + EXPECT_EQ("key-value", body_parameters["key"].s()); +} + TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) { GrapplerFunctionConnectivity connectivity; diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index 67ddb52d57..c608f9e1c6 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -46,18 +46,6 @@ Status MakeIteratorFromInputElement( return Status::OK(); } -IteratorContext MakeIteratorContext(OpKernelContext* ctx) { - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); - params.lib = ctx->function_library(); - DeviceBase* device = ctx->function_library()->device(); - params.allocator_getter = [device](AllocatorAttributes attrs) { - return device->GetAllocator(attrs); - }; - return IteratorContext(params); -} - } // namespace dataset } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index e5ca71dd99..6c4191c2be 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -28,8 +28,6 @@ Status MakeIteratorFromInputElement( int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator); -IteratorContext MakeIteratorContext(OpKernelContext* ctx); - } // namespace dataset } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index a2f6c5fe2c..b6bf0ecd09 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -1051,7 +1051,7 @@ class DeserializeIteratorOp : public OpKernel { IteratorResource* iterator_resource; OP_REQUIRES_OK( ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); - + core::ScopedUnref unref_iterator(iterator_resource); Variant variant = ctx->input(1).scalar<Variant>()(); auto* wrapper = variant.get<IteratorStateVariant>(); OP_REQUIRES(ctx, wrapper != nullptr, diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index c9551fbf16..729b615e56 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -14,6 +14,8 @@ limitations under the License. 
==============================================================================*/ #define EIGEN_USE_THREADS +#include <utility> + #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" @@ -21,6 +23,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/kernels/inplace_ops_functor.h" #include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/tracing.h" @@ -36,7 +39,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + graph_def_version_(ctx->graph_def_version()), + op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); @@ -59,12 +63,29 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ctx, batch_size > 0, errors::InvalidArgument("batch_size must be greater than zero.")); - int64 num_parallel_batches; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_batches", - &num_parallel_batches)); - OP_REQUIRES(ctx, num_parallel_batches > 0, - errors::InvalidArgument( - "num_parallel_batches must be greater than zero.")); + int64 num_parallel_calls; + switch (op_version_) { + case 1: + int64 num_parallel_batches; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_batches", + &num_parallel_batches)); + num_parallel_calls = num_parallel_batches * batch_size; + OP_REQUIRES(ctx, num_parallel_batches > 0, + errors::InvalidArgument( + "num_parallel_batches must be greater than zero.")); + break; + case 2: + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", + &num_parallel_calls)); + OP_REQUIRES(ctx, num_parallel_calls > 0, + errors::InvalidArgument( + "num_parallel_calls must be greater than zero.")); + break; + default: + OP_REQUIRES(ctx, false, + errors::Unimplemented("Unsupported operation version %d.", + op_version_)); + } bool drop_remainder; OP_REQUIRES_OK(ctx, @@ -74,7 +95,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, CapturedFunction::Create( func_, std::move(other_arguments), &captured_func)); - *output = new Dataset(ctx, input, batch_size, num_parallel_batches, + *output = new Dataset(ctx, input, batch_size, num_parallel_calls, drop_remainder, output_types_, output_shapes_, func_, std::move(captured_func), &ctx->eigen_cpu_device()); } @@ -83,7 +104,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { class Dataset : public GraphDatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size, - int64 num_parallel_batches, bool drop_remainder, + int64 num_parallel_calls, bool drop_remainder, const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes, const NameAttrList& func, @@ -92,7 +113,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { : GraphDatasetBase(ctx), input_(input), batch_size_(batch_size), - num_parallel_batches_(num_parallel_batches), + num_parallel_calls_(num_parallel_calls), drop_remainder_(drop_remainder), output_types_(output_types), 
output_shapes_(output_shapes), @@ -128,9 +149,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); Node* batch_size_node; TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node)); - Node* num_parallel_batches_node; + Node* num_parallel_calls_node; TF_RETURN_IF_ERROR( - b->AddScalar(num_parallel_batches_, &num_parallel_batches_node)); + b->AddScalar(num_parallel_calls_, &num_parallel_calls_node)); Node* drop_remainder_node; TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node)); @@ -153,7 +174,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { this, {std::make_pair(0, input_graph_node), std::make_pair(2, batch_size_node), - std::make_pair(3, num_parallel_batches_node), + std::make_pair(3, num_parallel_calls_node), std::make_pair(4, drop_remainder_node)}, // Single tensor inputs. {std::make_pair(1, other_arguments)}, // Tensor list inputs. {std::make_pair("f", f), @@ -168,129 +189,54 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { explicit Iterator(const Params& params) : DatasetIterator<Dataset>(params), input_impl_(params.dataset->input_->MakeIterator(params.prefix)), - invocation_results_(params.dataset->batch_size_ * - params.dataset->num_parallel_batches_), - batch_results_(params.dataset->num_parallel_batches_) {} + batch_results_((params.dataset->num_parallel_calls_ + + params.dataset->batch_size_ - 1) / + params.dataset->batch_size_) { + for (int i = 0; i < batch_results_.size(); ++i) { + batch_results_[i].Initialize(params.dataset->batch_size_); + } + } ~Iterator() override { - // TODO(mrry): Replace this cancellation logic with a - // CancellationManager. The syntax would be more heavyweight, - // but it would be possible to thread a cancellation manager - // through the IteratorContext to upstream, - // potentially-blocking iterators, when we add these. mutex_lock l(mu_); - if (current_batch_index_ != -1) { - for (size_t batch_index = 0; - batch_index < dataset()->num_parallel_batches_; ++batch_index) { - int64 num_elements; - WaitForBatch(batch_index, &num_elements).IgnoreError(); - // Deallocate tensors allocated for the output. - batch_results_[batch_index].output.clear(); - } + // Cancel the runner thread. + cancelled_ = true; + cond_var_.notify_all(); + // Wait for all in-flight calls to complete. + while (num_calls_ > 0) { + cond_var_.wait(l); } } - // TODO(jsimsa): Implement and profile the following alternative design: - // - // 0. Set the number of in-flight batches and invocations independently - // (though obviously the max number of in-flight invocations must be < - // batch_size * num_parallel_batches). Maintain a current producing batch - // index and offset. - // 1. Issue invocations in order of batch and offset, as you do currently. - // 2. When an invocation finishes, increment the current producing batch - // and offset. If that invocation would start a new batch and give more - // than num_parallel_batches in-flight, block; else start the new - // invocation into that location. - // 3. When a GetNext() call arrives, block until there's a full batch. - // Before returning the batch, if the number of pending invocations is - // less than the max, issue that number of invocations. Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - - // One-time initialization. 
- if (current_batch_index_ == -1) { - current_batch_index_ = 0; - for (size_t i = 0; i < dataset()->num_parallel_batches_; ++i) { - StartInvocationBatch(ctx, i); - } - } - - int64 num_elements = 0; - Status status = WaitForBatch(current_batch_index_, &num_elements); - if (num_elements == 0) { - *end_of_sequence = true; - return Status::OK(); - } - if (!status.ok()) { - // Deallocate tensors allocated for the output. - batch_results_[current_batch_index_].output.clear(); - } else { - if (num_elements < dataset()->batch_size_) { - if (dataset()->drop_remainder_) { - // Deallocate tensors allocated for the output. - batch_results_[current_batch_index_].output.clear(); - *end_of_sequence = true; - return Status::OK(); - } - const std::vector<Tensor>& output = - batch_results_[current_batch_index_].output; - for (size_t i = 0; i < output.size(); ++i) { - TensorShape component_shape( - batch_results_[current_batch_index_].output[i].shape()); - component_shape.set_dim(0, num_elements); - AllocatorAttributes attr; - attr.set_gpu_compatible(true); - Tensor component(ctx->allocator(attr), output[i].dtype(), - component_shape); - TF_RETURN_IF_ERROR( - CopyPartialBatch(&component, output[i], num_elements)); - out_tensors->emplace_back(std::move(component)); - } - // Deallocate tensors allocated for the output. - batch_results_[current_batch_index_].output.clear(); - } else { - *out_tensors = - std::move(batch_results_[current_batch_index_].output); - } - *end_of_sequence = false; - } - StartInvocationBatch(ctx, current_batch_index_); - current_batch_index_ = - (current_batch_index_ + 1) % dataset()->num_parallel_batches_; - return status; + EnsureRunnerThreadStarted(ctx); + BatchResult* result = &batch_results_[ComputeIndex(input_batch_)]; + WaitForBatch(result, &l); + return ProcessBatch(ctx, result, out_tensors, end_of_sequence); } protected: Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - if (current_batch_index_ == -1) { - // Iterator has not been used. Nothing to save. - return Status::OK(); + // Wait for all in-flight calls to complete. + while (num_calls_ > 0) { + cond_var_.wait(l); } - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_batch_index"), - current_batch_index_)); + CHECK_EQ(num_calls_, 0); TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); - // Wait for the map_fn dispatches made in `InvokeFunctionLocked` to - // finish. This may delay saving a checkpoint by a bit but keeps the - // code clean and also saves us from checkpointing the state of the - // `BlockingCounter`. 
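For the relationship between the two op versions and the new iterator state: assuming the constructor arithmetic shown above, num_parallel_calls is the v1 num_parallel_batches scaled by batch_size, and the batch_results_ ring holds ceil(num_parallel_calls / batch_size) slots addressed modulo its size (as in ComputeIndex further down). A plain-arithmetic sketch with illustrative values:

```
// v1 -> v2 parameter translation and circular-buffer sizing, standalone.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t batch_size = 32;
  const int64_t num_parallel_batches = 4;  // v1 attr
  const int64_t num_parallel_calls =
      num_parallel_batches * batch_size;   // v2 equivalent: 128

  // ceil(num_parallel_calls / batch_size) in-flight batch results.
  const int64_t num_batch_results =
      (num_parallel_calls + batch_size - 1) / batch_size;  // 4

  // ComputeIndex() analogue: batch n lives in slot n % ring size.
  for (int64_t n = 0; n < 6; ++n) {
    std::cout << "batch " << n << " -> slot " << n % num_batch_results
              << "\n";
  }
}
```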
- std::vector<int64> num_elements(batch_results_.size()); - for (size_t i = 0; i < batch_results_.size(); i++) { - WaitForBatch(i, &num_elements[i]).IgnoreError(); - } - - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name("invocation_results_size"), invocation_results_.size())); - for (size_t i = 0; i < invocation_results_.size(); ++i) { - TF_RETURN_IF_ERROR(WriteInvocationResultLocked(writer, i)); - } + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("call_counter"), call_counter_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_batch"), input_batch_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("output_batch"), output_batch_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("batch_results_size"), batch_results_.size())); for (size_t i = 0; i < batch_results_.size(); ++i) { - TF_RETURN_IF_ERROR( - WriteBatchResultLocked(writer, i, num_elements[i])); + TF_RETURN_IF_ERROR(WriteBatchResult(writer, i)); } return Status::OK(); } @@ -298,70 +244,136 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); - if (!reader->Contains(full_name("current_batch_index"))) { - // Iterator was never used so nothing to restore. - return Status::OK(); - } - { - int64 temp; - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("current_batch_index"), &temp)); - current_batch_index_ = static_cast<int32>(temp); - if (current_batch_index_ != temp) { - return errors::Internal("Invalid value for current_batch_index ", - temp); - } - } TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); - size_t invocation_results_size; - { - int64 temp; - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("invocation_results_size"), &temp)); - invocation_results_size = static_cast<size_t>(temp); - if (invocation_results_size != temp) { - return errors::Internal( - "Invalid value for invocation_results_size ", temp); - } - } - CHECK_EQ(invocation_results_.size(), invocation_results_size); - for (size_t i = 0; i < invocation_results_size; ++i) { - TF_RETURN_IF_ERROR(ReadInvocationResultLocked(reader, i)); - } - size_t batch_results_size; - { - int64 temp; - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("batch_results_size"), &temp)); - batch_results_size = static_cast<size_t>(temp); - if (batch_results_size != temp) { - return errors::Internal("Invalid value for batch_results_size ", - temp); - } - } + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("call_counter"), &call_counter_)); + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("input_batch"), &input_batch_)); + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("output_batch"), &output_batch_)); + int64 batch_results_size; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("batch_results_size"), + &batch_results_size)); CHECK_EQ(batch_results_.size(), batch_results_size); - for (size_t i = 0; i < batch_results_size; ++i) { - TF_RETURN_IF_ERROR(ReadBatchResultLocked(ctx, reader, i)); + for (int i = 0; i < batch_results_size; ++i) { + TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i)); } return Status::OK(); } private: struct BatchResult { - mutex mu ACQUIRED_AFTER(mu_); - bool output_allocated GUARDED_BY(mu); + mutex mu; + bool end_of_input GUARDED_BY(mu); + int64 num_elements GUARDED_BY(mu); std::vector<Tensor> output; - std::unique_ptr<BlockingCounter> counter; + bool output_allocated GUARDED_BY(mu); + Status status GUARDED_BY(mu); + // Used for coordination between the main thread and the callback + // threads. 
In particular, the main thread will wait for the value + // of `num_calls` to reach zero before processing the batch result. + condition_variable cond_var; // access guarded by owner's mutex + // Counts the number of outstanding calls for this batch. + int64 num_calls; // access guarded by owner's mutex + + void Initialize(int64 batch_size) { + mutex_lock l(mu); + end_of_input = false; + num_calls = batch_size; + num_elements = 0; + output_allocated = false; + status = Status::OK(); + } + + void UpdateStatus(const Status& s) { + mutex_lock l(mu); + status.Update(s); + } }; - struct InvocationResult { - Status status; + void Callback(const std::shared_ptr<IteratorContext>& ctx, + BatchResult* result, std::vector<Tensor>* return_values, + int64 offset, const Status& status) { + std::unique_ptr<std::vector<Tensor>> cleanup_retvals(return_values); + result->UpdateStatus(status); + if (status.ok()) { + EnsureOutputAllocated(ctx, result, return_values); + for (size_t i = 0; i < return_values->size(); ++i) { + const Tensor& tensor = return_values->at(i); + Tensor* batch = &(result->output)[i]; + if (tensor.NumElements() != + (batch->NumElements() / batch->dim_size(0))) { + TensorShape batch_shape = batch->shape(); + batch_shape.RemoveDim(0); + result->UpdateStatus(errors::InvalidArgument( + "Cannot add tensor to the batch: number of elements does not " + "match. Shapes are: [tensor]: ", + tensor.shape().DebugString(), + ", [batch]: ", batch_shape.DebugString())); + break; + } + // TODO(mrry): Add a version of DoParallelConcat that allows us to + // move `tensor` where possible, to speed up string tensor batching. + Status copy_status = ::tensorflow::functor::DoParallelConcat( + *dataset()->device_, tensor, offset, batch); + if (!copy_status.ok()) { + result->UpdateStatus(copy_status); + break; + } + } + } + { + mutex_lock l(result->mu); + result->num_elements++; + } + { + mutex_lock l(mu_); + CallCompleted(result); + } + } + + void CallCompleted(BatchResult* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + num_calls_--; + cond_var_.notify_all(); + result->num_calls--; + result->cond_var.notify_all(); + } + + void CallFunction(std::shared_ptr<IteratorContext> ctx, + BatchResult* result, int64 offset) { + // Get the next input element. + std::vector<Tensor> input_element; bool end_of_input; - std::vector<Tensor> return_values; - }; + Status status = + input_impl_->GetNext(ctx.get(), &input_element, &end_of_input); + { + mutex_lock l(mu_); + mutex_lock l2(result->mu); + result->end_of_input = result->end_of_input || end_of_input; + result->status.Update(status); + if (result->end_of_input || !result->status.ok()) { + CallCompleted(result); + return; + } + } - int64 ComputeInvocationIndex(int64 batch_index, int64 offset) { - return batch_index * dataset()->batch_size_ + offset; + // Call `captured_func_(input_element)`, using `Callback` to store the + // result in `result`. 
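The Callback/CallCompleted pair above implements a bounded in-flight-call pattern: each completion decrements shared counters under mu_ and notifies cond_var_ so the scheduler can issue more work. A self-contained analogue with std::thread (illustrative cap of 4 and 16 calls; no TF types):

```
// Standalone sketch of the coordination contract: schedule only while the
// in-flight count is below the cap; completions decrement and notify.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

int main() {
  std::mutex mu;
  std::condition_variable cond_var;
  int num_calls = 0;                 // in-flight calls, guarded by mu
  const int max_parallel_calls = 4;  // cap, like num_parallel_calls_
  std::vector<std::thread> workers;

  for (int i = 0; i < 16; ++i) {
    std::unique_lock<std::mutex> l(mu);
    cond_var.wait(l, [&] { return num_calls < max_parallel_calls; });
    ++num_calls;
    workers.emplace_back([&] {
      // ... run the mapped function here ...
      std::lock_guard<std::mutex> l2(mu);
      --num_calls;            // CallCompleted() analogue
      cond_var.notify_all();
    });
  }
  for (auto& t : workers) t.join();
  std::cout << "all calls completed\n";
}
```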
+ (*ctx->runner())(std::bind( + [this, result, offset](std::shared_ptr<IteratorContext> ctx, + std::vector<Tensor> input_element) { + std::vector<Tensor>* return_values = new std::vector<Tensor>(); + dataset()->captured_func_->RunAsync( + ctx.get(), std::move(input_element), return_values, + [this, ctx, result, return_values, offset](Status status) { + Callback(ctx, result, return_values, offset, status); + }); + }, + ctx, std::move(input_element))); + } + + int64 ComputeIndex(int64 n) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return n % batch_results_.size(); } Status CopyPartialBatch(Tensor* output, const Tensor& value, @@ -387,253 +399,140 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - void EnsureOutputAllocated(IteratorContext* ctx, - BatchResult* batch_result, - const std::vector<Tensor>& return_values) { - mutex_lock l(batch_result->mu); - if (batch_result->output_allocated) { + void EnsureRunnerThreadStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!runner_thread_) { + std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); + runner_thread_.reset(ctx->env()->StartThread( + {}, "runner_thread", + std::bind(&Iterator::RunnerThread, this, ctx_copy))); + } + } + + void EnsureOutputAllocated(const std::shared_ptr<IteratorContext>& ctx, + BatchResult* result, + const std::vector<Tensor>* return_values) { + mutex_lock l(result->mu); + if (result->output_allocated) { return; } - const size_t num_components = return_values.size(); + const size_t num_components = return_values->size(); for (size_t i = 0; i < num_components; ++i) { TensorShape component_shape({dataset()->batch_size_}); - component_shape.AppendShape(return_values[i].shape()); + component_shape.AppendShape(return_values->at(i).shape()); AllocatorAttributes attr; attr.set_gpu_compatible(true); - Tensor component(ctx->allocator(attr), return_values[i].dtype(), + Tensor component(ctx->allocator(attr), return_values->at(i).dtype(), component_shape); - batch_result->output.emplace_back(std::move(component)); + result->output.emplace_back(std::move(component)); } - batch_result->output_allocated = true; + result->output_allocated = true; } - void InvokeFunctionLocked(IteratorContext* ctx, int64 batch_index, - int64 offset) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - size_t index = ComputeInvocationIndex(batch_index, offset); - InvocationResult* result = &invocation_results_[index]; - BatchResult* batch_result = &batch_results_[batch_index]; - - // Get the next input element. - std::vector<Tensor> input_element; - result->status = - input_impl_->GetNext(ctx, &input_element, &result->end_of_input); - if (result->end_of_input || !result->status.ok()) { - batch_result->counter->DecrementCount(); - return; + Status ProcessBatch(IteratorContext* ctx, BatchResult* result, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto cleanup = + gtl::MakeCleanup([this, result]() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + result->Initialize(dataset()->batch_size_); + input_batch_++; + }); + mutex_lock l(result->mu); + if (result->num_elements == 0) { + *end_of_sequence = true; + return Status::OK(); } - // Call `captured_func_(input_element)`, store the result in - // `result->return_values`, and notify `batch_result->counter` - // to unblock a consumer. 
- (*ctx->runner())(std::bind( - [this, result, batch_result, offset]( - IteratorContext* ctx, std::vector<Tensor> input_element) { - dataset()->captured_func_->RunAsync( - ctx, std::move(input_element), &result->return_values, - [this, ctx, result, batch_result, offset](Status ret_status) { - result->status.Update(ret_status); - if (ret_status.ok()) { - EnsureOutputAllocated(ctx, batch_result, - result->return_values); - const size_t num_components = - result->return_values.size(); - for (size_t i = 0; i < num_components; ++i) { - const Tensor& tensor = result->return_values[i]; - Tensor* batch = &(batch_result->output)[i]; - if (tensor.NumElements() != - (batch->NumElements() / batch->dim_size(0))) { - TensorShape batch_shape = batch->shape(); - batch_shape.RemoveDim(0); - result->status.Update(errors::InvalidArgument( - "Cannot add tensor to the batch: number of " - "elements does not match. Shapes are: [tensor]: ", - tensor.shape().DebugString(), - ", [batch]: ", batch_shape.DebugString())); - break; - } - // TODO(mrry): Add a version of DoParallelConcat that - // allows us to move `tensor` where possible, to speed - // up string tensor batching. - Status copy_status = - ::tensorflow::functor::DoParallelConcat( - *dataset()->device_, tensor, offset, batch); - if (!copy_status.ok()) { - result->status.Update(copy_status); - break; - } - } - } - delete ctx; - // NOTE(mrry): We clear the return values here to release - // any memory associated with them and to paralellize the - // destruction of the tensors (which can be surprisingly - // expensive for map functions with large numbers of return - // values). - result->return_values.clear(); - batch_result->counter->DecrementCount(); - }); - }, - new IteratorContext(*ctx), std::move(input_element))); - } - - void StartInvocationBatch(IteratorContext* ctx, int64 batch_index) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - tracing::ScopedActivity activity(strings::StrCat(prefix(), "::Start")); - // Initialize batch result. - { - mutex_lock l(batch_results_[batch_index].mu); - batch_results_[batch_index].output_allocated = false; - batch_results_[batch_index].counter.reset( - new BlockingCounter(dataset()->batch_size_)); - } - // Initialize invocation results. - for (size_t i = 0; i < dataset()->batch_size_; ++i) { - size_t index = ComputeInvocationIndex(batch_index, i); - InvocationResult* result = &invocation_results_[index]; - // Reset the state of `result`; `result->return_values` was cleared - // when the previous invocation completed. - result->end_of_input = false; - result->status = Status::OK(); - } - // Start individual invocations. - for (size_t i = 0; i < dataset()->batch_size_; ++i) { - InvokeFunctionLocked(ctx, batch_index, i); + if (!result->status.ok()) { + // Deallocate tensors allocated for the output. + result->output.clear(); + } else { + if (result->num_elements < dataset()->batch_size_) { + if (dataset()->drop_remainder_) { + // Deallocate tensors allocated for the output. 
+ result->output.clear(); + *end_of_sequence = true; + return Status::OK(); + } + const std::vector<Tensor>& output = result->output; + for (size_t i = 0; i < output.size(); ++i) { + TensorShape component_shape(result->output[i].shape()); + component_shape.set_dim(0, result->num_elements); + AllocatorAttributes attr; + attr.set_gpu_compatible(true); + Tensor component(ctx->allocator(attr), output[i].dtype(), + component_shape); + TF_RETURN_IF_ERROR(CopyPartialBatch(&component, output[i], + result->num_elements)); + out_tensors->emplace_back(std::move(component)); + } + // Deallocate tensors allocated for the output. + result->output.clear(); + } else { + *out_tensors = std::move(result->output); + } + *end_of_sequence = false; } + cond_var_.notify_all(); + return result->status; } - Status WaitForBatch(int64 batch_index, int64* num_elements) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - tracing::ScopedActivity activity(strings::StrCat(prefix(), "::Wait")); - batch_results_[batch_index].counter->Wait(); - Status status = Status::OK(); - for (size_t i = 0; i < dataset()->batch_size_; ++i, ++*num_elements) { - size_t index = ComputeInvocationIndex(batch_index, i); - InvocationResult* result = &invocation_results_[index]; - if (result->end_of_input) { - VLOG(3) << "end of input encountered at element[" << i << "]: "; - return Status::OK(); - } - if (!result->status.ok()) { - VLOG(3) << "failed to process element[" << i - << "]: " << result->status; - status.Update(result->status); + void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { + mutex_lock l(mu_); + while (true) { + while (!cancelled_ && + (num_calls_ == dataset()->num_parallel_calls_ || + (output_batch_ - input_batch_ == batch_results_.size()))) { + cond_var_.wait(l); } - } - return status; - } - Status WriteInvocationResultLocked(IteratorStateWriter* writer, - size_t index) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - const InvocationResult& result = invocation_results_[index]; - string prefix = strings::StrCat("invocation_results_", index); - TF_RETURN_IF_ERROR(WriteStatusLocked( - writer, full_name(strings::StrCat(prefix, "_status")), - result.status)); - if (result.end_of_input) { - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat(prefix, "_end_of_input")), "")); - } - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat(prefix, "_return_values_size")), - result.return_values.size())); - for (size_t i = 0; i < result.return_values.size(); i++) { - TF_RETURN_IF_ERROR(writer->WriteTensor( - full_name(strings::StrCat(prefix, "_return_values_", i)), - result.return_values[i])); - } - return Status::OK(); - } + if (cancelled_) { + return; + } - Status ReadInvocationResultLocked(IteratorStateReader* reader, - size_t index) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - InvocationResult* result = &invocation_results_[index]; - string prefix = strings::StrCat("invocation_results_", index); - TF_RETURN_IF_ERROR(ReadStatusLocked( - reader, full_name(strings::StrCat(prefix, "_status")), - &result->status)); - result->end_of_input = reader->Contains( - full_name(strings::StrCat(prefix, "_end_of_input"))); - size_t return_values_size; - { - int64 temp; - TF_RETURN_IF_ERROR(reader->ReadScalar( - full_name(strings::StrCat(prefix, "_return_values_size")), - &temp)); - return_values_size = static_cast<size_t>(temp); - if (temp != return_values_size) { - return errors::Internal("Invalid value for return_values_size ", - return_values_size); + while (num_calls_ < dataset()->num_parallel_calls_ && + (output_batch_ - input_batch_ < 
batch_results_.size())) { + BatchResult* result = &batch_results_[ComputeIndex(output_batch_)]; + int64 offset = call_counter_++ % dataset()->batch_size_; + num_calls_++; + mu_.unlock(); + CallFunction(ctx, result, offset); + mu_.lock(); + if (offset + 1 == dataset()->batch_size_) { + // Done scheduling calls for the current batch. + output_batch_++; + } } } - result->return_values.reserve(return_values_size); - for (size_t i = 0; i < return_values_size; i++) { - result->return_values.emplace_back(); - TF_RETURN_IF_ERROR(reader->ReadTensor( - full_name(strings::StrCat(prefix, "_return_values_", i)), - &result->return_values.back())); - } - return Status::OK(); } - Status WriteBatchResultLocked(IteratorStateWriter* writer, size_t index, - int64 num_elements) + void WaitForBatch(BatchResult* result, mutex_lock* l) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - const BatchResult& result = batch_results_[index]; - string prefix = strings::StrCat("batch_results_", index); - { - mutex_lock l(batch_results_[index].mu); - if (result.output_allocated) { - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat(prefix, "_output_allocated")), "")); - } + while (result->num_calls > 0) { + result->cond_var.wait(*l); } - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat(prefix, "_output_size")), - result.output.size())); - for (size_t i = 0; i < result.output.size(); i++) { - // If the batch is not full, we only store the first - // `num_elements` values. The rest of the batch tensor is - // *uninitialized* and accessing that will raise msan errors. - if (num_elements < dataset()->batch_size_) { - TF_RETURN_IF_ERROR(writer->WriteTensor( - full_name(strings::StrCat(prefix, "_output_", i)), - result.output[i].Slice(0, num_elements))); - } else { - TF_RETURN_IF_ERROR(writer->WriteTensor( - full_name(strings::StrCat(prefix, "_output_", i)), - result.output[i])); - } - } - return Status::OK(); } - Status ReadBatchResultLocked(IteratorContext* ctx, - IteratorStateReader* reader, size_t index) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader, + size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) { BatchResult* result = &batch_results_[index]; string prefix = strings::StrCat("batch_results_", index); - { - mutex_lock l(batch_results_[index].mu); - result->output_allocated = reader->Contains( - full_name(strings::StrCat(prefix, "_output_allocated"))); - // Simulate that the batch was fully generated. 
- batch_results_[index].counter.reset(new BlockingCounter(0)); - } - size_t output_size; - { - int64 temp; - TF_RETURN_IF_ERROR(reader->ReadScalar( - full_name(strings::StrCat(prefix, "_output_size")), &temp)); - output_size = static_cast<size_t>(temp); - if (temp != output_size) { - return errors::Internal("Invalid value for output_size ", - output_size); - } - } + mutex_lock l(result->mu); + result->end_of_input = reader->Contains( + full_name(strings::StrCat(prefix, "_end_of_input"))); + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name(strings::StrCat(prefix, "_num_calls")), + &result->num_calls)); + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat(prefix, "_num_elements")), + &result->num_elements)); + result->output_allocated = reader->Contains( + full_name(strings::StrCat(prefix, "_output_allocated"))); + int64 output_size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat(prefix, "_output_size")), &output_size)); result->output.reserve(output_size); - for (size_t i = 0; i < output_size; i++) { + for (int i = 0; i < output_size; i++) { Tensor t; TF_RETURN_IF_ERROR(reader->ReadTensor( full_name(strings::StrCat(prefix, "_output_", i)), &t)); @@ -653,25 +552,13 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { result->output.emplace_back(std::move(t)); } } + TF_RETURN_IF_ERROR(ReadStatus( + reader, strings::StrCat(prefix, "_status"), &result->status)); return Status::OK(); } - Status WriteStatusLocked(IteratorStateWriter* writer, - const string& prefix, const Status& status) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")), - static_cast<int64>(status.code()))); - if (!status.ok()) { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name(strings::StrCat(prefix, "_msg")), - status.error_message())); - } - return Status::OK(); - } - - Status ReadStatusLocked(IteratorStateReader* reader, const string& prefix, - Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + Status ReadStatus(IteratorStateReader* reader, const string& prefix, + Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { int64 code_int; TF_RETURN_IF_ERROR(reader->ReadScalar( full_name(strings::StrCat(prefix, "_code")), &code_int)); @@ -687,17 +574,89 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } return Status::OK(); } + + Status WriteBatchResult(IteratorStateWriter* writer, size_t index) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + BatchResult* result = &batch_results_[index]; + string prefix = strings::StrCat("batch_results_", index); + mutex_lock l(result->mu); + if (result->end_of_input) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(prefix, "_end_of_input")), "")); + } + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(prefix, "_num_calls")), + result->num_calls)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(prefix, "_num_elements")), + result->num_elements)); + if (result->output_allocated) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(prefix, "_output_allocated")), "")); + } + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat(prefix, "_output_size")), + result->output.size())); + for (int i = 0; i < result->output.size(); i++) { + // If the batch is not full, we only store the first `num_elements` + // values. The rest of the batch tensor is *uninitialized* and + // accessing that will raise msan errors. 
+ if (result->num_elements < dataset()->batch_size_) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat(prefix, "_output_", i)), + result->output[i].Slice(0, result->num_elements))); + } else { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat(prefix, "_output_", i)), + result->output[i])); + } + } + TF_RETURN_IF_ERROR(WriteStatus( + writer, strings::StrCat(prefix, "_status"), result->status)); + return Status::OK(); + } + + Status WriteStatus(IteratorStateWriter* writer, const string& prefix, + const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")), + static_cast<int64>(status.code()))); + if (!status.ok()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name(strings::StrCat(prefix, "_msg")), + status.error_message())); + } + return Status::OK(); + } + mutex mu_; - int32 current_batch_index_ GUARDED_BY(mu_) = -1; - const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); - std::vector<InvocationResult> invocation_results_ GUARDED_BY(mu_); + // Used for coordination between the main thread, the runner thread, and + // the callback threads. In particular, the runner thread should only + // schedule new calls when the number of in-flight calls is less than the + // user specified level of parallelism and there are slots available in + // the `batch_results_` buffer. + condition_variable cond_var_; + // Counts the number of outstanding calls for this batch. + int64 num_calls_ GUARDED_BY(mu_) = 0; + // Counts the total number of calls. + int64 call_counter_ GUARDED_BY(mu_) = 0; + const std::unique_ptr<IteratorBase> input_impl_; + // Identifies the next batch to be read by the caller. + int64 input_batch_ GUARDED_BY(mu_) = 0; + // Identifies the next batch to create. + int64 output_batch_ GUARDED_BY(mu_) = 0; + // Circular buffer for storing the (intermediate) batch results. When + // using `input_batch_` and `output_batch_` to index into the buffer, + // their value should be interpreted modulo the size of the buffer. 
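The WriteStatus/ReadStatus helpers above persist a Status as an integer code plus a message that is only written when the status is non-OK. The round trip can be sketched with stand-in types (SimpleStatus is hypothetical, not tensorflow::Status):

```
// A Status survives a checkpoint as (code, optional message).
#include <iostream>
#include <string>
#include <utility>

struct SimpleStatus {
  int code = 0;  // 0 plays the role of error::OK
  std::string msg;
  bool ok() const { return code == 0; }
};

std::pair<int, std::string> Write(const SimpleStatus& s) {
  return {s.code, s.ok() ? "" : s.msg};  // message written only if !ok
}

SimpleStatus Read(const std::pair<int, std::string>& record) {
  SimpleStatus s;
  s.code = record.first;
  if (!s.ok()) s.msg = record.second;
  return s;
}

int main() {
  SimpleStatus failed{3, "invalid argument"};
  SimpleStatus restored = Read(Write(failed));
  std::cout << restored.code << ": " << restored.msg << "\n";
}
```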
std::vector<BatchResult> batch_results_ GUARDED_BY(mu_); + std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); + bool cancelled_ GUARDED_BY(mu_) = false; }; const DatasetBase* const input_; const NameAttrList func_; const int64 batch_size_; - const int64 num_parallel_batches_; + const int64 num_parallel_calls_; const bool drop_remainder_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; @@ -707,6 +666,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { }; const int graph_def_version_; + const int op_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; NameAttrList func_; @@ -715,6 +675,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("MapAndBatchDataset").Device(DEVICE_CPU), MapAndBatchDatasetOp); +REGISTER_KERNEL_BUILDER(Name("MapAndBatchDatasetV2").Device(DEVICE_CPU), + MapAndBatchDatasetOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index a8bcc7f7dc..03cc414905 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -703,6 +703,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU); REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate", scatter_op::UpdateOp::ASSIGN); +REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate", + scatter_op::UpdateOp::ASSIGN); REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate", scatter_op::UpdateOp::ASSIGN); @@ -728,6 +730,13 @@ REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate") REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate") .Device(DEVICE_GPU) .HostMemory("resource") + .TypeConstraint<bool>("dtype") + .TypeConstraint<int32>("Tindices"), + ResourceScatterUpdateOp<GPUDevice, bool, int32, + scatter_op::UpdateOp::ASSIGN>) +REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate") + .Device(DEVICE_GPU) + .HostMemory("resource") .HostMemory("indices") .TypeConstraint<Variant>("dtype") .TypeConstraint<int64>("Tindices"), diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc index 59911bf0d2..bdc878594a 100644 --- a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc @@ -42,6 +42,8 @@ typedef Eigen::GpuDevice GPUDevice; DEFINE_GPU_SPECS(float); DEFINE_GPU_SPECS(double); +DEFINE_GPU_SPECS_OP(bool, int32, scatter_op::UpdateOp::ASSIGN); +DEFINE_GPU_SPECS_OP(bool, int64, scatter_op::UpdateOp::ASSIGN); // TODO(b/27222123): The following fails to compile due to lack of support for // fp16. 
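The new DT_BOOL registrations above extend ResourceScatterUpdate; with UpdateOp::ASSIGN the kernel's contract is an indexed overwrite, params[indices[i], ...] = updates[i, ...]. Sketched in plain C++ for the 1-D case (illustrative values, not the kernel itself):

```
// What the bool ASSIGN kernels compute, reduced to the 1-D case.
#include <iostream>
#include <vector>

int main() {
  std::vector<bool> params = {false, false, false, false};
  const std::vector<int> indices = {0, 2};
  const std::vector<bool> updates = {true, true};

  for (size_t i = 0; i < indices.size(); ++i) {
    params[indices[i]] = updates[i];  // scatter_op::UpdateOp::ASSIGN
  }

  for (bool b : params) std::cout << b << ' ';  // prints: 1 0 1 0
}
```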
// TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 3db00d8180..6880ceb505 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -28130,6 +28130,54 @@ op { } } op { + name: "MapAndBatchDatasetV2" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "num_parallel_calls" + type: DT_INT64 + } + input_arg { + name: "drop_remainder" + type: DT_BOOL + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "MapClear" attr { name: "capacity" diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 73174c184c..576946eddd 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -208,6 +208,19 @@ REGISTER_OP("MapAndBatchDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("MapAndBatchDatasetV2") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("batch_size: int64") + .Input("num_parallel_calls: int64") + .Input("drop_remainder: bool") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("PrefetchDataset") .Input("input_dataset: variant") .Input("buffer_size: int64") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 7156440b46..d741598b19 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -13940,6 +13940,54 @@ op { } } op { + name: "MapAndBatchDatasetV2" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "num_parallel_calls" + type: DT_INT64 + } + input_arg { + name: "drop_remainder" + type: DT_BOOL + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "MapClear" attr { name: "capacity" diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 107c38114b..f6e09ef094 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -335,6 +335,8 @@ def tf_proto_library_cc(name, srcs = [], has_services = None, name = cc_name, deps = cc_deps + ["@protobuf_archive//:protobuf_headers"] + if_static([name + "_cc_impl"]), + testonly = testonly, + visibility = visibility, ) native.cc_library( name = cc_name + "_impl", @@ -378,8 +380,10 @@ def tf_proto_library_py(name, srcs=[], protodeps=[], 
deps=[], visibility=[],
 )
 native.py_library(
     name = py_name,
-    deps = py_deps + ["@protobuf_archive//:protobuf_python"])
-
+    deps = py_deps + ["@protobuf_archive//:protobuf_python"],
+    testonly = testonly,
+    visibility = visibility,
+  )
 
 return py_proto_library(
diff --git a/tensorflow/core/platform/default/mutex.h b/tensorflow/core/platform/default/mutex.h
index a12d92795e..89e57d58a0 100644
--- a/tensorflow/core/platform/default/mutex.h
+++ b/tensorflow/core/platform/default/mutex.h
@@ -77,9 +77,7 @@ class SCOPED_LOCKABLE mutex_lock {
   // Manually nulls out the source to prevent double-free.
   // (std::move does not null the source pointer by default.)
-  explicit mutex_lock(mutex_lock&& ml) noexcept : mu_(ml.mu_) {
-    ml.mu_ = nullptr;
-  }
+  mutex_lock(mutex_lock&& ml) noexcept : mu_(ml.mu_) { ml.mu_ = nullptr; }
   ~mutex_lock() UNLOCK_FUNCTION() {
     if (mu_ != nullptr) {
       mu_->unlock();
diff --git a/tensorflow/docs_src/deploy/index.md b/tensorflow/docs_src/deploy/index.md
index 61edba04b4..3322004189 100644
--- a/tensorflow/docs_src/deploy/index.md
+++ b/tensorflow/docs_src/deploy/index.md
@@ -15,3 +15,7 @@ the following documents:
   out-of-the-box integration with TensorFlow models.
   [Source code for TensorFlow Serving](https://github.com/tensorflow/serving)
   is available on GitHub.
+
+[TensorFlow Extended (TFX)](/tfx) is an end-to-end machine learning platform
+for TensorFlow. TFX was developed at Google; some of its libraries have been
+open sourced, with the rest of the system to follow.
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index f530fe1206..21e4c71a60 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1049,8 +1049,8 @@ For a more intuitive description, see the "Informal Description" section below.
 :                    :                         : from.                          :
 |`gather_indices`    | `ComputationDataHandle` | Tensor containing the starting |
 :                    :                         : indices of the slices we're    :
-:                    :                         : we're stitching together into  :
-:                    :                         : the output tensor.             :
+:                    :                         : stitching together into the    :
+:                    :                         : output tensor.                 :
 |`index_vector_dim`  | `int64`                 | The dimension in               |
 :                    :                         : `gather_indices` that contains :
 :                    :                         : the starting indices.          :
diff --git a/tensorflow/docs_src/programmers_guide/embedding.md b/tensorflow/docs_src/programmers_guide/embedding.md
index d5703e0737..8a98367dfb 100644
--- a/tensorflow/docs_src/programmers_guide/embedding.md
+++ b/tensorflow/docs_src/programmers_guide/embedding.md
@@ -238,7 +238,7 @@ row doesn't have to be filled, as shown below.
   </tr>
 </table>
 
-Follow [this link]("https://www.tensorflow.org/images/embedding-mnist.mp4" )
+Follow [this link](https://www.tensorflow.org/images/embedding-mnist.mp4)
 to see a fun example of thumbnail images in the Embedding Projector.
 
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 2f1be51ada..70a271bd2e 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -2544,6 +2544,72 @@ func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values
 	return op.Output(0)
 }
 
+// Reverses specific dimensions of a tensor.
+//
+// Given a `tensor`, and a `bool` tensor `dims` representing the dimensions
+// of `tensor`, this operation reverses each dimension i of `tensor` where
+// `dims[i]` is `True`.
+//
+// `tensor` can have up to 8 dimensions. The number of dimensions
+// of `tensor` must equal the number of elements in `dims`.
In other words: +// +// `rank(tensor) = size(dims)` +// +// For example: +// +// ``` +// # tensor 't' is [[[[ 0, 1, 2, 3], +// # [ 4, 5, 6, 7], +// # [ 8, 9, 10, 11]], +// # [[12, 13, 14, 15], +// # [16, 17, 18, 19], +// # [20, 21, 22, 23]]]] +// # tensor 't' shape is [1, 2, 3, 4] +// +// # 'dims' is [False, False, False, True] +// reverse(t, dims) ==> [[[[ 3, 2, 1, 0], +// [ 7, 6, 5, 4], +// [ 11, 10, 9, 8]], +// [[15, 14, 13, 12], +// [19, 18, 17, 16], +// [23, 22, 21, 20]]]] +// +// # 'dims' is [False, True, False, False] +// reverse(t, dims) ==> [[[[12, 13, 14, 15], +// [16, 17, 18, 19], +// [20, 21, 22, 23] +// [[ 0, 1, 2, 3], +// [ 4, 5, 6, 7], +// [ 8, 9, 10, 11]]]] +// +// # 'dims' is [False, False, True, False] +// reverse(t, dims) ==> [[[[8, 9, 10, 11], +// [4, 5, 6, 7], +// [0, 1, 2, 3]] +// [[20, 21, 22, 23], +// [16, 17, 18, 19], +// [12, 13, 14, 15]]]] +// ``` +// +// Arguments: +// tensor: Up to 8-D. +// dims: 1-D. The dimensions to reverse. +// +// Returns The same shape as `tensor`. +func Reverse(scope *Scope, tensor tf.Output, dims tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Reverse", + Input: []tf.Input{ + tensor, dims, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Clips tensor values to a specified min and max. // // Given a tensor `t`, this operation returns a tensor of the same type and @@ -2796,71 +2862,6 @@ func Asin(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// SparseToDenseAttr is an optional argument to SparseToDense. -type SparseToDenseAttr func(optionalAttr) - -// SparseToDenseValidateIndices sets the optional validate_indices attribute to value. -// -// value: If true, indices are checked to make sure they are sorted in -// lexicographic order and that there are no repeats. -// If not specified, defaults to true -func SparseToDenseValidateIndices(value bool) SparseToDenseAttr { - return func(m optionalAttr) { - m["validate_indices"] = value - } -} - -// Converts a sparse representation into a dense tensor. -// -// Builds an array `dense` with shape `output_shape` such that -// -// ``` -// # If sparse_indices is scalar -// dense[i] = (i == sparse_indices ? sparse_values : default_value) -// -// # If sparse_indices is a vector, then for each i -// dense[sparse_indices[i]] = sparse_values[i] -// -// # If sparse_indices is an n by d matrix, then for each i in [0, n) -// dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] -// ``` -// -// All other values in `dense` are set to `default_value`. If `sparse_values` is a -// scalar, all sparse indices are set to this single value. -// -// Indices should be sorted in lexicographic order, and indices must not -// contain any repeats. If `validate_indices` is true, these properties -// are checked during execution. -// -// Arguments: -// sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete -// index where `sparse_values[i]` will be placed. -// output_shape: 1-D. Shape of the dense output tensor. -// sparse_values: 1-D. Values corresponding to each row of `sparse_indices`, -// or a scalar value to be used for all sparse indices. -// default_value: Scalar value to set for indices not specified in -// `sparse_indices`. -// -// Returns Dense output tensor of shape `output_shape`. 
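Two hedged sketches for the hunks above. First, the tf.data usage that corresponds to the MapAndBatchDatasetV2 kernel registered in dataset_ops.cc: in this era the transformation lives in tf.contrib.data, and the V2 op swaps the old num_parallel_batches input for num_parallel_calls, so the exact Python-side argument plumbing may differ by version.

```python
import tensorflow as tf

# Fused map + batch in one dataset transformation; drop_remainder matches
# the op's boolean input of the same name.
dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(
    tf.contrib.data.map_and_batch(
        map_func=lambda x: x * 2,
        batch_size=4,
        drop_remainder=True))
```

Second, a minimal NumPy model (not TensorFlow's implementation) of the SparseToDense semantics documented directly above:

```python
import numpy as np

def sparse_to_dense(sparse_indices, output_shape, sparse_values,
                    default_value=0):
    # Scatter sparse_values into a default_value-filled array, covering the
    # three index layouts described in the doc comment: scalar, vector, and
    # n x d matrix of complete indices.
    dense = np.full(output_shape, default_value)
    idx = np.asarray(sparse_indices)
    if idx.ndim == 0:
        dense[int(idx)] = sparse_values
    elif idx.ndim == 1:
        dense[idx] = sparse_values
    else:
        dense[tuple(idx.T)] = sparse_values
    return dense

print(sparse_to_dense([[0, 1], [2, 3]], [3, 4], [5, 6]))
# [[0 5 0 0]
#  [0 0 0 0]
#  [0 0 0 6]]
```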
-func SparseToDense(scope *Scope, sparse_indices tf.Output, output_shape tf.Output, sparse_values tf.Output, default_value tf.Output, optional ...SparseToDenseAttr) (dense tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SparseToDense", - Input: []tf.Input{ - sparse_indices, output_shape, sparse_values, default_value, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes the sum along sparse segments of a tensor. // // Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of @@ -6469,72 +6470,6 @@ func SparseFillEmptyRows(scope *Scope, indices tf.Output, values tf.Output, dens return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// Reverses specific dimensions of a tensor. -// -// Given a `tensor`, and a `bool` tensor `dims` representing the dimensions -// of `tensor`, this operation reverses each dimension i of `tensor` where -// `dims[i]` is `True`. -// -// `tensor` can have up to 8 dimensions. The number of dimensions -// of `tensor` must equal the number of elements in `dims`. In other words: -// -// `rank(tensor) = size(dims)` -// -// For example: -// -// ``` -// # tensor 't' is [[[[ 0, 1, 2, 3], -// # [ 4, 5, 6, 7], -// # [ 8, 9, 10, 11]], -// # [[12, 13, 14, 15], -// # [16, 17, 18, 19], -// # [20, 21, 22, 23]]]] -// # tensor 't' shape is [1, 2, 3, 4] -// -// # 'dims' is [False, False, False, True] -// reverse(t, dims) ==> [[[[ 3, 2, 1, 0], -// [ 7, 6, 5, 4], -// [ 11, 10, 9, 8]], -// [[15, 14, 13, 12], -// [19, 18, 17, 16], -// [23, 22, 21, 20]]]] -// -// # 'dims' is [False, True, False, False] -// reverse(t, dims) ==> [[[[12, 13, 14, 15], -// [16, 17, 18, 19], -// [20, 21, 22, 23] -// [[ 0, 1, 2, 3], -// [ 4, 5, 6, 7], -// [ 8, 9, 10, 11]]]] -// -// # 'dims' is [False, False, True, False] -// reverse(t, dims) ==> [[[[8, 9, 10, 11], -// [4, 5, 6, 7], -// [0, 1, 2, 3]] -// [[20, 21, 22, 23], -// [16, 17, 18, 19], -// [12, 13, 14, 15]]]] -// ``` -// -// Arguments: -// tensor: Up to 8-D. -// dims: 1-D. The dimensions to reverse. -// -// Returns The same shape as `tensor`. -func Reverse(scope *Scope, tensor tf.Output, dims tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Reverse", - Input: []tf.Input{ - tensor, dims, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // BiasAddGradAttr is an optional argument to BiasAddGrad. type BiasAddGradAttr func(optionalAttr) @@ -24884,6 +24819,71 @@ func DecodeJSONExample(scope *Scope, json_examples tf.Output) (binary_examples t return op.Output(0) } +// SparseToDenseAttr is an optional argument to SparseToDense. +type SparseToDenseAttr func(optionalAttr) + +// SparseToDenseValidateIndices sets the optional validate_indices attribute to value. +// +// value: If true, indices are checked to make sure they are sorted in +// lexicographic order and that there are no repeats. +// If not specified, defaults to true +func SparseToDenseValidateIndices(value bool) SparseToDenseAttr { + return func(m optionalAttr) { + m["validate_indices"] = value + } +} + +// Converts a sparse representation into a dense tensor. +// +// Builds an array `dense` with shape `output_shape` such that +// +// ``` +// # If sparse_indices is scalar +// dense[i] = (i == sparse_indices ? 
sparse_values : default_value) +// +// # If sparse_indices is a vector, then for each i +// dense[sparse_indices[i]] = sparse_values[i] +// +// # If sparse_indices is an n by d matrix, then for each i in [0, n) +// dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] +// ``` +// +// All other values in `dense` are set to `default_value`. If `sparse_values` is a +// scalar, all sparse indices are set to this single value. +// +// Indices should be sorted in lexicographic order, and indices must not +// contain any repeats. If `validate_indices` is true, these properties +// are checked during execution. +// +// Arguments: +// sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete +// index where `sparse_values[i]` will be placed. +// output_shape: 1-D. Shape of the dense output tensor. +// sparse_values: 1-D. Values corresponding to each row of `sparse_indices`, +// or a scalar value to be used for all sparse indices. +// default_value: Scalar value to set for indices not specified in +// `sparse_indices`. +// +// Returns Dense output tensor of shape `output_shape`. +func SparseToDense(scope *Scope, sparse_indices tf.Output, output_shape tf.Output, sparse_values tf.Output, default_value tf.Output, optional ...SparseToDenseAttr) (dense tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseToDense", + Input: []tf.Input{ + sparse_indices, output_shape, sparse_values, default_value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. // // The `input` tensor has shape `[batch, in_height, in_width, depth]` and the diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 087b89b125..a865e8ca75 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1762,6 +1762,7 @@ py_library( ":logging_ops_gen", ":math_ops", ":platform", + ":resource_variable_ops_gen", ":sparse_tensor", ":tensor_array_ops", ":tf_should_use", @@ -4134,7 +4135,7 @@ cuda_py_test( py_test( name = "saver_large_variable_test", - size = "small", + size = "medium", srcs = ["training/saver_large_variable_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py index 00090b21fe..7cbaae46b4 100644 --- a/tensorflow/python/debug/examples/debug_tflearn_iris.py +++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py @@ -140,7 +140,7 @@ def main(_): # Make predictions, using tfdbg hook. predict_results = classifier.predict(test_input_fn, hooks=hooks) - print("A prediction result: %s" % predict_results.next()) + print("A prediction result: %s" % next(predict_results)) if __name__ == "__main__": diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index d04b004451..967c128280 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -358,6 +358,8 @@ def gradients_function(f, params=None): assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 ``` + Note that only tensors with real or complex dtypes are differentiable. + Args: f: function to be differentiated. If `f` returns a scalar, this scalar will be differentiated. 
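The note added above, that only tensors with real or complex dtypes are differentiable, is observable directly. A small sketch, assuming the tf.contrib.eager API of this era:

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

def f(x):
    return x * x

grad_f = tfe.gradients_function(f)
print(grad_f(tf.constant(3.0))[0])  # tf.Tensor(6.0, ...): float is recorded
print(grad_f(tf.constant(3))[0])    # None: integer dtypes are not recorded
```

This is the behavior pinned down by testGradientInteger in backprop_test.py below.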
If `f` returns a tensor or list of tensors, by default @@ -700,6 +702,9 @@ class GradientTape(object): dz_dx = g.gradient(z, x) # 108.0 (4*x^3 at x = 3) dy_dx = g.gradient(y, x) # 6.0 del g # Drop the reference to the tape + ``` + + Note that only tensors with real or complex dtypes are differentiable. """ def __init__(self, persistent=False): diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 8d9959fe20..be674487f1 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -124,6 +124,14 @@ class BackpropTest(test.TestCase): grad_fn = backprop.gradients_function(f) self.assertAllEqual(2., grad_fn(1., dy=2.)[0]) + def testGradientInteger(self): + + def f(x): + return x + x + + int_tensor = constant_op.constant(1) + self.assertEqual(backprop.gradients_function(f)(int_tensor)[0], None) + def testErrors(self): @custom_gradient.custom_gradient @@ -753,7 +761,7 @@ class BackpropTest(test.TestCase): return result, grad x = resource_variable_ops.ResourceVariable( - initial_value=3, name='X.' + self.id()) + initial_value=3., name='X.' + self.id()) def f(): return my_square(x) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 741bd2ac9c..b478b6b0db 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -23,6 +23,7 @@ import collections import numpy as np +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context @@ -102,13 +103,15 @@ class CapturingGraph(ops.Graph): def clear_resource_control_flow_state(self): self._last_op_using_resource_tensor = {} - def maybe_capture_tensor(self, tensor): + def capture(self, tensor, name=None): if isinstance(tensor, ops.EagerTensor): - return capture_value( - self.captures, tensor, tensor.dtype, str(ops.uid())) + if name is None: + name = str(ops.uid()) + return capture_value(self.captures, tensor, tensor.dtype, name) if tensor.graph is not self: - return capture_value( - self.captures, tensor, tensor.dtype, tensor.op.name) + if name is None: + name = tensor.op.name + return capture_value(self.captures, tensor, tensor.dtype, name) return tensor def create_op( @@ -126,7 +129,7 @@ class CapturingGraph(ops.Graph): # forward the resources such as Identity and Switch can cause serialization # to fail. for i, inp in enumerate(inputs): - inputs[i] = self.maybe_capture_tensor(inp) + inputs[i] = self.capture(inp) return super(CapturingGraph, self).create_op( op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_shapes, compute_device) @@ -225,7 +228,7 @@ def _inference_name(n): class _EagerDefinedFunction(object): """Function object with the interface of tf _DefinedFunction.""" - def __init__(self, name, graph, operations, inputs, outputs): + def __init__(self, name, graph, operations, inputs, outputs, attrs): """Initializes an eager defined function. 
Args: @@ -235,6 +238,7 @@ class _EagerDefinedFunction(object): which will be in the function inputs: the tensors in the graph to be used as inputs to the function outputs: the tensors in the graph which will be outputs to the function + attrs: dict mapping names of attributes to their AttrValue values """ fn = pywrap_tensorflow.TF_GraphToFunction_wrapper( graph._c_graph, # pylint: disable=protected-access @@ -246,6 +250,14 @@ class _EagerDefinedFunction(object): [], None, compat.as_str("")) + + for name, attr_value in attrs.items(): + serialized = attr_value.SerializeToString() + # TODO(iga): this creates and deletes a new TF_Status for every attr. + # It might be worth creating a convenient way to re-use status. + pywrap_tensorflow.TF_FunctionSetAttrValueProto( + fn, compat.as_str(name), serialized) + # TODO(apassos) avoid creating a FunctionDef (specially to grab the # signature, but also in general it's nice not to depend on it. with c_api_util.tf_buffer() as buffer_: @@ -287,25 +299,6 @@ def _flatten(sequence): class GraphModeFunction(object): """Callable object representing a graph-mode function. - - Args: - name: str the name of the created function - input_placeholders: list of placeholder values (tensors) to feed when - calling the wrapped function. - extra_inputs: Tensor inputs this function definition closed over which - are passed as arguments. Need to track so gradients are supported - correctly. - graph: the Graph from which the operations will be pulled. Used as - a context when computing gradients. - operations: the subset of Operations in the graph used in the function - definition. - outputs: a flat list of the Tensors in the graph used as outputs to the - function - func_outputs: a possibly nested python object which will be returned by - this function. The Tensors in this structure will be replaced by their - corresponding values in outputs. - output_shapes: List of shapes of all tensors in outputs - variables: (optional) List of variables to watch during function execution. """ def __init__(self, @@ -317,9 +310,36 @@ class GraphModeFunction(object): outputs, func_outputs, output_shapes, - variables=None): + variables=None, + attrs=None): + """Initialize a GraphModeFunction. + + Args: + name: str the name of the created function + input_placeholders: list of placeholder values (tensors) to feed when + calling the wrapped function. + extra_inputs: Tensor inputs this function definition closed over which + are passed as arguments. Need to track so gradients are supported + correctly. + graph: the Graph from which the operations will be pulled. Used as + a context when computing gradients. + operations: the subset of Operations in the graph used in the function + definition. + outputs: a flat list of the Tensors in the graph used as outputs to the + function + func_outputs: a possibly nested python object which will be returned by + this function. The Tensors in this structure will be replaced by their + corresponding values in outputs. + output_shapes: List of shapes of all tensors in outputs + variables: (optional) List of variables to watch during function + execution. + attrs: (optional) dict mapping names of attributes to their AttrValue + values. Attributes in `attrs` will be included in this function's + definition. 
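The new attrs plumbing above reduces to attaching serialized AttrValue protos to the function definition. A condensed sketch of that loop:

```python
from tensorflow.core.framework import attr_value_pb2

# Attributes are plain AttrValue protos keyed by name; the XLA path later in
# this diff attaches {"_XlaCompile": AttrValue(b=True)}. Each proto is
# serialized and handed to TF_FunctionSetAttrValueProto, as in the C++ call
# above.
attrs = {"_XlaCompile": attr_value_pb2.AttrValue(b=True)}
for name, attr_value in attrs.items():
    serialized = attr_value.SerializeToString()
    # pywrap_tensorflow.TF_FunctionSetAttrValueProto(fn, name, serialized)
```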
+ """ + self._attrs = attrs or {} defined_function = _EagerDefinedFunction( - name, graph, operations, input_placeholders, outputs) + name, graph, operations, input_placeholders, outputs, self._attrs) if len(input_placeholders) != len(defined_function.signature.input_arg): raise ValueError("Internal error: invalid lengths. %s %s" % ( len(input_placeholders), len(defined_function.signature.input_arg))) @@ -372,7 +392,7 @@ class GraphModeFunction(object): forward_name = _forward_name(self._func_name) self._forward_fdef = _EagerDefinedFunction( forward_name, self._graph, self._ops, self._input_placeholders, - filtered_outputs + captures) + filtered_outputs + captures, self._attrs) all_inputs = self._out_grad_placeholders + captures # Excluding input ops from the body as we do not intend to execute these # operations when the function is executed. @@ -386,7 +406,7 @@ class GraphModeFunction(object): bname = _backward_name(self._func_name) self._backward_function = GraphModeFunction( bname, all_inputs, [], self._graph, function_def_ops, - backward_outputs, in_gradients, output_shapes) + backward_outputs, in_gradients, output_shapes, attrs=self._attrs) def _backprop_call(self, args): """Calls the wrapped function and records the result on a tape.""" @@ -560,7 +580,7 @@ def _get_defun_inputs(args): return nest.pack_sequence_as(args, ret) -def _defun_internal(name, func, args, kwds): +def _defun_internal(name, func, compiled, args, kwds): """Defines and returns graph-mode version of func.""" graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access with context.graph_mode(): @@ -598,7 +618,7 @@ def _defun_internal(name, func, args, kwds): # call to convert_to_tensor, so we manually capture all such tensors. outputs_list = _flatten(func_outputs) func_def_outputs = [ - tmp_graph.maybe_capture_tensor(x) for x in outputs_list + tmp_graph.capture(x) for x in outputs_list if x is not None ] @@ -625,9 +645,14 @@ def _defun_internal(name, func, args, kwds): for f in tmp_graph._functions.values(): # pylint: disable=protected-access # TODO(ashankar): What about the gradient registry? _register(f._c_func.func) # pylint: disable=protected-access + + attrs = {} + if compiled: + attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True) + return GraphModeFunction( fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs, - func_outputs, output_shapes, variables) + func_outputs, output_shapes, variables, attrs) # Defun uses this instead of Tensor as a cache key. Using dtype because @@ -669,7 +694,7 @@ def _register(fn): # TODO(apassos): better error messages for non-hashable arguments. -def named_defun(func, name): +def named_defun(func, name, compiled=False): """Defines a function with a given name. See the documentation for `defun` for more information on the semantics of the @@ -678,6 +703,7 @@ def named_defun(func, name): Args: func: the function to be wrapped. name: the name given to it. + compiled: if true, the framework will attempt to compile func with XLA. Returns: the wrapped function. @@ -694,13 +720,13 @@ def named_defun(func, name): if cache_key not in arguments_to_functions: arguments_to_functions[cache_key] = _defun_internal( - name, func, args, kwds) + name, func, compiled, args, kwds) return arguments_to_functions[cache_key](*args) return decorated -def defun(func): +def defun(func=None, compiled=False): """Decorator to compile func into graph_mode. 
`defun` converts a function that constructs a TensorFlow graph into a function @@ -743,18 +769,45 @@ def defun(func): ``` Args: - func: function to be compiled. + func: function to be compiled. If `func` is None, returns a + decorator that can be invoked with a single argument - `func`. The + end result is equivalent to providing all the arguments up front. + In other words, defun(compiled=True)(func) is equivalent to + defun(func, compiled=True). The former allows the following use case: + @tfe.defun(compiled=True) + def foo(...): + ... + compiled: If True, an attempt to compile `func` with XLA will be made. + If it fails, function will be run normally. Experimental. + Currently, supported only for execution on TPUs. Returns: - A callable that will execute the compiled function (and return zero - or more `tf.Tensor` objects). + If `func` is not None, returns callable that will execute the compiled + function (and return zero or more `tf.Tensor` objects). + If `func` is None, returns a decorator that, when invoked with a single + `func` argument, returns a callable equivalent to the case above. """ # TODO(apassos): deal with captured global state. Deal with control flow. - try: - name = func.__name__ - except AttributeError: - name = "function" - return tf_decorator.make_decorator(func, named_defun(func, name)) + def decorated(function): + try: + name = function.__name__ + except AttributeError: + name = "function" + return tf_decorator.make_decorator( + function, named_defun(function, name, compiled=compiled)) + + # This code path is for the `foo = tfe.defun(foo, ...)` use case + if func is not None: + return decorated(func) + + # This code path is for the + # + # @tfe.defun(...) + # def foo(...): + # ... + # + # use case, which is equivalent to `foo = tfe.defun(...)(foo)` + return decorated def make_defun_op(func, *args, **kwds): @@ -806,7 +859,7 @@ def make_defun_op(func, *args, **kwds): name = func.__name__ if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): raise ValueError("Tensor keyword arguments are not supported.") - return _defun_internal(name, func, args, kwds) + return _defun_internal(name, func, False, args, kwds) class AutomaticControlDependencies(object): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 185f6d981c..f53d6c2608 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -771,6 +771,21 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(val.eval(feed_dict={p: False}), 10.0) self.assertAllEqual(val.eval(feed_dict={p: True}), 20.0) + def testDefunWhileLoopWithCapturedLoopVars(self): + n = 3 + x = constant_op.constant(list(range(n))) + + @function.defun + def loop(): + c = lambda i, x: i < n + b = lambda i, x: (i + 1, x + 1) + i, out = control_flow_ops.while_loop(c, b, (0, x)) + return i, out + + i, out = loop() + self.assertEqual(int(i), 3) + self.assertAllEqual(out, [3, 4, 5]) + def testDecorator(self): with context.graph_mode(), self.test_session(): v = resource_variable_ops.ResourceVariable(1.0) diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index b5b4e394e3..b3aadd55ce 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -650,6 +650,12 @@ tensorflow::int64 EagerTensor_id(const PyObject* tensor) { return reinterpret_cast<const EagerTensor*>(tensor)->id; } +tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) { + 
CHECK(EagerTensor_CheckExact(tensor)); + return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType( + reinterpret_cast<const EagerTensor*>(tensor)->handle)); +} + PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { if (!PyType_Check(base_class)) { PyErr_SetString( diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h index 63ab1ed84d..88982b0c85 100644 --- a/tensorflow/python/eager/pywrap_tensor.h +++ b/tensorflow/python/eager/pywrap_tensor.h @@ -21,6 +21,7 @@ limitations under the License. bool EagerTensor_CheckExact(const PyObject* o); tensorflow::int64 EagerTensor_id(const PyObject* tensor); +tensorflow::DataType EagerTensor_dtype(const PyObject* tensor); namespace tensorflow { TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 4ecba1a46b..48a5b21dc7 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -843,6 +843,24 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) { return id; } +static tensorflow::DataType FastTensorDtype(PyObject* tensor) { + if (EagerTensor_CheckExact(tensor)) { + return EagerTensor_dtype(tensor); + } + PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype"); + if (dtype_field == nullptr) { + return tensorflow::DT_INVALID; + } + PyObject* enum_field = PyObject_GetAttrString(dtype_field, "_type_enum"); + Py_DECREF(dtype_field); + if (dtype_field == nullptr) { + return tensorflow::DT_INVALID; + } + tensorflow::int64 id = MakeInt(enum_field); + Py_DECREF(enum_field); + return static_cast<tensorflow::DataType>(id); +} + class GradientTape : public tensorflow::eager::GradientTape<PyObject, PyObject> { public: @@ -1053,15 +1071,18 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) { // TODO(apassos) consider not building a list and changing the API to check // each tensor individually. 
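What FastTensorDtype above extracts, viewed from the Python side (a sketch of the attribute walk, not the C++ fast path). One review note: after Py_DECREF(dtype_field), the null check above appears to re-test `dtype_field`, which was already verified non-null; `enum_field` was presumably intended.

```python
def fast_tensor_dtype(tensor):
    # tensor.dtype._type_enum is the DataType enum value (e.g. DT_FLOAT == 1,
    # DT_INT32 == 3) that the tape now records alongside tensor ids so that
    # non-differentiable dtypes can be skipped.
    return int(tensor.dtype._type_enum)
```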
std::vector<tensorflow::int64> tensor_ids; + std::vector<tensorflow::DataType> dtypes; tensor_ids.reserve(len); + dtypes.reserve(len); for (int i = 0; i < len; ++i) { PyObject* item = PySequence_Fast_GET_ITEM(seq, i); tensor_ids.push_back(FastTensorId(item)); + dtypes.push_back(FastTensorDtype(item)); } Py_DECREF(seq); auto tape_set = *tape_set_ptr; for (TFE_Py_Tape* tape : tape_set) { - if (tape->tape->ShouldRecord(tensor_ids)) { + if (tape->tape->ShouldRecord(tensor_ids, dtypes)) { Py_RETURN_TRUE; } } @@ -1169,9 +1190,27 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { } namespace { -void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, - const std::vector<tensorflow::int64>& input_ids, - PyObject* backward_function) { +std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) { + PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); + if (seq == nullptr) { + return {}; + } + int len = PySequence_Fast_GET_SIZE(seq); + std::vector<tensorflow::DataType> list; + list.reserve(len); + for (int i = 0; i < len; ++i) { + PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i); + list.push_back(FastTensorDtype(tensor)); + } + Py_DECREF(seq); + return list; +} + +void TapeSetRecordOperation( + PyObject* op_type, PyObject* output_tensors, + const std::vector<tensorflow::int64>& input_ids, + const std::vector<tensorflow::DataType>& input_dtypes, + PyObject* backward_function) { std::vector<tensorflow::eager::TapeTensor> output_info; PyObject* seq = PySequence_Fast(output_tensors, "expected a sequence of integer tensor ids"); @@ -1206,7 +1245,7 @@ void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, for (TFE_Py_Tape* tape : SafeTapeSet()) { Py_INCREF(backward_function); tape->tape->RecordOperation( - op_type_str, output_info, input_ids, backward_function, + op_type_str, output_info, input_ids, input_dtypes, backward_function, [backward_function]() { Py_DECREF(backward_function); }); } } @@ -1221,7 +1260,11 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors); if (PyErr_Occurred()) return; - TapeSetRecordOperation(op_type, output_tensors, input_ids, backward_function); + std::vector<tensorflow::DataType> input_dtypes = + MakeTensorDtypeList(input_tensors); + if (PyErr_Occurred()) return; + TapeSetRecordOperation(op_type, output_tensors, input_ids, input_dtypes, + backward_function); } void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { @@ -1710,10 +1753,12 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, PyObject* results, PyObject* name) { std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs); if (PyErr_Occurred()) return nullptr; + std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs); + if (PyErr_Occurred()) return nullptr; bool should_record = false; for (TFE_Py_Tape* tape : SafeTapeSet()) { - if (tape->tape->ShouldRecord(input_ids)) { + if (tape->tape->ShouldRecord(input_ids, input_dtypes)) { should_record = true; break; } @@ -1744,7 +1789,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, Py_DECREF(callback_args); if (backward_function == nullptr) return nullptr; - TapeSetRecordOperation(op_name, results, input_ids, backward_function); + TapeSetRecordOperation(op_name, results, input_ids, input_dtypes, + backward_function); Py_DECREF(backward_function); diff --git a/tensorflow/python/estimator/BUILD 
b/tensorflow/python/estimator/BUILD index 56dec1eaa1..2d9a084bc6 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -91,6 +91,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/saved_model:signature_constants", + "//tensorflow/python/saved_model:tag_constants", "@six_archive//:six", ], ) @@ -488,6 +489,7 @@ py_library( py_test( name = "estimator_test", srcs = ["estimator_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["notsan"], # b/67510291 deps = [ diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index 973a6ec747..e7fbf8eb72 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -154,6 +154,59 @@ def _dnn_model_fn(features, Raises: ValueError: If features has the wrong type. """ + tpu_estimator_spec = _tpu_dnn_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + hidden_units=hidden_units, + feature_columns=feature_columns, + optimizer=optimizer, + activation_fn=activation_fn, + dropout=dropout, + input_layer_partitioner=input_layer_partitioner, + config=config) + return tpu_estimator_spec.as_estimator_spec() + + +def _tpu_dnn_model_fn(features, + labels, + mode, + head, + hidden_units, + feature_columns, + optimizer='Adagrad', + activation_fn=nn.relu, + dropout=None, + input_layer_partitioner=None, + config=None): + """Deep Neural Net model_fn for TPUEstimator. + + Args: + features: dict of `Tensor`. + labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of + dtype `int32` or `int64` in the range `[0, n_classes)`. + mode: Defines whether this is training, evaluation or prediction. + See `ModeKeys`. + head: A `head_lib._Head` instance. + hidden_units: Iterable of integer number of hidden units per layer. + feature_columns: Iterable of `feature_column._FeatureColumn` model inputs. + optimizer: String, `tf.Optimizer` object, or callable that creates the + optimizer to use for training. If not specified, will use the Adagrad + optimizer with a default learning rate of 0.05. + activation_fn: Activation function applied to each layer. + dropout: When not `None`, the probability we will drop out a given + coordinate. + input_layer_partitioner: Partitioner for input layer. Defaults + to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: `RunConfig` object to configure the runtime settings. + + Returns: + A `model_fn.TPUEstimatorSpec` instance. + + Raises: + ValueError: If features has the wrong type. + """ if not isinstance(features, dict): raise ValueError('features should be a dictionary of `Tensor`s. ' 'Given type: {}'.format(type(features))) @@ -182,7 +235,7 @@ def _dnn_model_fn(features, input_layer_partitioner=input_layer_partitioner) logits = logit_fn(features=features, mode=mode) - return head.create_estimator_spec( + return head._create_tpu_estimator_spec( # pylint: disable=protected-access features=features, mode=mode, labels=labels, @@ -320,17 +373,8 @@ class DNNClassifier(estimator.Estimator): loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. 
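The dnn.py refactor above routes the canned DNN through a TPU-shaped spec. Condensed to its essence (names from the diff, argument lists elided):

```python
# The public model_fn now builds a _TPUEstimatorSpec via _tpu_dnn_model_fn
# and immediately converts it for use with the plain Estimator.
def _dnn_model_fn(features, labels, mode, head, **kwargs):
    tpu_estimator_spec = _tpu_dnn_model_fn(features, labels, mode, head,
                                           **kwargs)
    return tpu_estimator_spec.as_estimator_spec()
```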
""" - if n_classes == 2: - head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access - weight_column=weight_column, - label_vocabulary=label_vocabulary, - loss_reduction=loss_reduction) - else: - head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access - n_classes, weight_column=weight_column, - label_vocabulary=label_vocabulary, - loss_reduction=loss_reduction) - + head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access + n_classes, weight_column, label_vocabulary, loss_reduction) def _model_fn(features, labels, mode, config): """Call the defined shared _dnn_model_fn.""" return _dnn_model_fn( diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py index 62b13c3200..06a648777f 100644 --- a/tensorflow/python/estimator/canned/dnn_testing_utils.py +++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py @@ -134,7 +134,7 @@ def mock_head(testcase, hidden_units, logits_dimension, expected_logits): hidden_weights_names + hidden_biases_names + [LOGITS_WEIGHTS_NAME + '/part_0:0', LOGITS_BIASES_NAME + '/part_0:0']) - def _create_estimator_spec( + def _create_tpu_estimator_spec( features, mode, logits, labels, train_op_fn=None, optimizer=None): del features, labels # Not used. trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) @@ -149,19 +149,29 @@ def mock_head(testcase, hidden_units, logits_dimension, expected_logits): train_op = train_op_fn(loss) elif optimizer is not None: train_op = optimizer.minimize(loss, global_step=None) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( mode=mode, loss=loss, train_op=train_op) elif mode == model_fn.ModeKeys.EVAL: - return model_fn.EstimatorSpec(mode=mode, loss=array_ops.identity(loss)) + return model_fn._TPUEstimatorSpec( + mode=mode, loss=array_ops.identity(loss)) elif mode == model_fn.ModeKeys.PREDICT: - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( mode=mode, predictions={'logits': array_ops.identity(logits)}) else: testcase.fail('Invalid mode: {}'.format(mode)) + def _create_estimator_spec( + features, mode, logits, labels, train_op_fn=None, optimizer=None): + tpu_spec = _create_tpu_estimator_spec( + features, mode, logits, labels, train_op_fn, optimizer) + return tpu_spec.as_estimator_spec() + head = test.mock.NonCallableMagicMock(spec=head_lib._Head) head.logits_dimension = logits_dimension - head.create_estimator_spec = test.mock.MagicMock(wraps=_create_estimator_spec) + head._create_tpu_estimator_spec = test.mock.MagicMock( + wraps=_create_tpu_estimator_spec) + head.create_estimator_spec = test.mock.MagicMock( + wraps=_create_estimator_spec) return head diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index 48f448d7f5..232637314d 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -32,6 +32,7 @@ from tensorflow.python.feature_column import feature_column as feature_column_li from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -69,6 +70,35 @@ def _summary_key(head_name, val): return '%s/%s' % (val, 
head_name) if head_name else val +def _create_eval_metrics_tuple(fn, kwargs): + """Creates TPU eval metrics tuple. + + Helper function to make eval_metric tuple (eval_metric_fn, fn_kwargs) used + by `TPUEstimator`. TPUEstimator requires that `eval_metric_fn` take + exclusively Tensor arguments. This helper can help create such a function from + a more generic function that can take both Tensor and non-Tensor arguments. + + Args: + fn: A eval_metric_fn that takes both Tensor and non-Tensor arguments. + This function must return a dict of form + {'metric name': (metric_tensor, eval_op)} + kwargs: Dict of arguments for `fn`. + + Returns: + `eval_metric` tuple that can be passed to a `model_fn._TPUEstimatorSpec`. + """ + tensor_kwargs = {} + nontensor_kwargs = {} + for k, v in six.iteritems(kwargs): + if tensor_util.is_tensor(v): + tensor_kwargs[k] = v + else: + nontensor_kwargs[k] = v + def _fn(**tensors): + return fn(**dict(nontensor_kwargs, **tensors)) + return (_fn, tensor_kwargs) + + class _Head(object): """Interface for the head/top of a model. @@ -174,7 +204,6 @@ class _Head(object): # TODO(b/65403806): By default, collect regularization_losses from # GraphKeys.REGULARIZATION_LOSSES collection. - @abc.abstractmethod def create_estimator_spec( self, features, mode, logits, labels=None, optimizer=None, train_op_fn=None, regularization_losses=None): @@ -203,7 +232,47 @@ class _Head(object): Returns: `EstimatorSpec`. """ - raise NotImplementedError('Calling an abstract method.') + try: + tpu_estimator_spec = ( + self._create_tpu_estimator_spec( + features, mode, logits, labels, optimizer, train_op_fn, + regularization_losses)) + return tpu_estimator_spec.as_estimator_spec() + except NotImplementedError: + # Not all subclasses of _Head will have implemented + # _create_tpu_estimator_spec. If it is implemented, we can use it to + # create our `EstimatorSpec` here. + raise NotImplementedError( + 'Subclasses of _Head must implement `create_estimator_spec()` or ' + '_create_tpu_estimator_spec().') + + def _create_tpu_estimator_spec( + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): + """Returns `model_fn._TPUEstimatorSpec` that a model_fn can return. + + Args: + features: Input `dict` of `Tensor` or `SparseTensor` objects. + mode: Estimator's `ModeKeys`. + logits: logits `Tensor` to be used by the head. + labels: Labels `Tensor`, or `dict` of same. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. + train_op_fn: Function that takes a scalar loss `Tensor` and returns an op + to optimize the model with the loss in TRAIN mode. Used if `optimizer` + is `None`. Exactly one of `train_op_fn` and `optimizer` must be set in + TRAIN mode. None is allowed in other modes. If you want to optimize loss + yourself you can pass `lambda _: tf.no_op()` and then use + EstimatorSpec.loss to compute and apply gradients. + regularization_losses: A list of additional scalar losses to be added to + the training loss, such as regularization losses. + + Returns: + A `model_fn._TPUEstimatorSpec' instance. 
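A hypothetical usage of the `_create_eval_metrics_tuple` helper defined above; the metric fn and tensors here are invented for illustration:

```python
def accuracy_fn(labels, predictions, name):
    # `name` is a plain string, so the helper closes over it rather than
    # routing it through TPU infeed like the Tensor arguments.
    return {name: metrics_lib.accuracy(labels=labels, predictions=predictions)}

eval_metrics = _create_eval_metrics_tuple(
    accuracy_fn,
    {
        'labels': labels,            # Tensor: forwarded to the metric fn
        'predictions': predictions,  # Tensor: forwarded to the metric fn
        'name': 'my_accuracy',       # non-Tensor: captured in the closure
    })
# eval_metrics == (_fn, {'labels': labels, 'predictions': predictions})
```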
+ """ + raise NotImplementedError( + 'TPUEstimatorSpec not available for this model head.') def _check_dense_labels_match_logits_and_reshape( @@ -702,10 +771,10 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): weights=weights, processed_labels=label_ids) - def create_estimator_spec( + def _create_tpu_estimator_spec( self, features, mode, logits, labels=None, optimizer=None, train_op_fn=None, regularization_losses=None): - """Returns an `EstimatorSpec`. + """Returns a `model_fn._TPUEstimatorSpec`. Args: features: Input `dict` of `Tensor` or `SparseTensor` objects. @@ -727,7 +796,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to avoid scaling errors. Returns: - `EstimatorSpec`. + A `model_fn._TPUEstimatorSpec` instance. Raises: ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode, or if both are set. @@ -761,7 +830,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): classifier_output = _classification_output( scores=probabilities, n_classes=self._n_classes, label_vocabulary=self._label_vocabulary) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={ @@ -781,16 +850,17 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): regularized_training_loss = training_loss # Eval. if mode == model_fn.ModeKeys.EVAL: - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access mode=model_fn.ModeKeys.EVAL, predictions=predictions, loss=regularized_training_loss, - eval_metric_ops=self._eval_metric_ops( - labels=label_ids, - class_ids=class_ids, - weights=weights, - unreduced_loss=unreduced_loss, - regularization_loss=regularization_loss)) + eval_metrics=_create_eval_metrics_tuple(self._eval_metric_ops, { + 'labels': label_ids, + 'class_ids': class_ids, + 'weights': weights, + 'unreduced_loss': unreduced_loss, + 'regularization_loss': regularization_loss + })) # Train. if optimizer is not None: @@ -824,7 +894,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): summary.scalar( _summary_key(self._name, keys.LOSS_REGULARIZATION), regularization_loss) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, @@ -1060,7 +1130,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): weights=weights, processed_labels=labels) - def create_estimator_spec( + def _create_tpu_estimator_spec( self, features, mode, logits, labels=None, optimizer=None, train_op_fn=None, regularization_losses=None): """Returns an `EstimatorSpec`. @@ -1122,7 +1192,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): classifier_output = _classification_output( scores=probabilities, n_classes=2, label_vocabulary=self._label_vocabulary) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={ @@ -1146,18 +1216,22 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): # Eval. 
if mode == model_fn.ModeKeys.EVAL: - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access mode=model_fn.ModeKeys.EVAL, predictions=predictions, loss=regularized_training_loss, - eval_metric_ops=self._eval_metric_ops( - labels=processed_labels, - logits=logits, - logistic=logistic, - class_ids=class_ids, - weights=weights, - unreduced_loss=unreduced_loss, - regularization_loss=regularization_loss)) + eval_metrics=_create_eval_metrics_tuple( + self._eval_metric_ops, + { + 'labels': processed_labels, + 'logits': logits, + 'logistic': logistic, + 'class_ids': class_ids, + 'weights': weights, + 'unreduced_loss': unreduced_loss, + 'regularization_loss': regularization_loss + } + )) # Train. if optimizer is not None: @@ -1190,7 +1264,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): summary.scalar( _summary_key(self._name, keys.LOSS_REGULARIZATION), regularization_loss) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, @@ -1322,7 +1396,25 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): weights=weights, processed_labels=labels) - def create_estimator_spec( + def _eval_metric_ops(self, weights, unreduced_loss, regularization_loss): + """Returns the Eval metric ops.""" + keys = metric_keys.MetricKeys + # Estimator already adds a metric for loss. + eval_metric_ops = { + _summary_key(self._name, keys.LOSS_MEAN): + metrics_lib.mean( + values=unreduced_loss, + weights=weights) + } + if regularization_loss is not None: + regularization_loss_key = _summary_key( + self._name, keys.LOSS_REGULARIZATION) + eval_metric_ops[regularization_loss_key] = metrics_lib.mean( + values=regularization_loss, + name=keys.LOSS_REGULARIZATION) + return eval_metric_ops + + def _create_tpu_estimator_spec( self, features, mode, logits, labels=None, optimizer=None, train_op_fn=None, regularization_losses=None): """Returns an `EstimatorSpec`. @@ -1348,7 +1440,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to avoid scaling errors. Returns: - `EstimatorSpec`. + A `model_fn._TPUEstimatorSpec` instance. Raises: ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode, or if both are set. @@ -1369,7 +1461,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): if mode == model_fn.ModeKeys.PREDICT: regression_output = export_output.RegressionOutput( value=predicted_value) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={ @@ -1390,25 +1482,18 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): # Eval. if mode == model_fn.ModeKeys.EVAL: - keys = metric_keys.MetricKeys - # Estimator already adds a metric for loss. 
- eval_metric_ops = { - _summary_key(self._name, keys.LOSS_MEAN): - metrics_lib.mean( - values=unreduced_loss, - weights=weights) - } - if regularization_loss is not None: - regularization_loss_key = _summary_key( - self._name, keys.LOSS_REGULARIZATION) - eval_metric_ops[regularization_loss_key] = metrics_lib.mean( - values=regularization_loss, - name=keys.LOSS_REGULARIZATION) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access mode=model_fn.ModeKeys.EVAL, predictions=predictions, loss=regularized_training_loss, - eval_metric_ops=eval_metric_ops) + eval_metrics=_create_eval_metrics_tuple( + self._eval_metric_ops, + { + 'weights': weights, + 'unreduced_loss': unreduced_loss, + 'regularization_loss': regularization_loss, + } + )) # Train. if optimizer is not None: @@ -1441,7 +1526,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): summary.scalar( _summary_key(self._name, keys.LOSS_REGULARIZATION), regularization_loss) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint: disable=protected-access mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, @@ -1478,3 +1563,42 @@ def _weights(features, weight_column): raise ValueError('Weight column should be castable to float. ' 'Given dtype: {}'.format(weights.dtype)) return math_ops.to_float(weights, name='weights') + + +def _binary_logistic_or_multi_class_head( + n_classes, weight_column, label_vocabulary, loss_reduction): + """Creates either binary or multi-class head. + + Args: + n_classes: Number of label classes. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + label_vocabulary: A list of strings represents possible label values. If + given, labels must be string type and have any value in + `label_vocabulary`. If it is not given, that means labels are + already encoded as integer or float within [0, 1] for `n_classes=2` and + encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . + Also there will be errors if vocabulary is not provided and labels are + string. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM`. + + Returns: + `head._Head` instance. 
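Hedged usage of the head factory documented above (its body follows below); the vocabulary is invented for illustration, and `losses` refers to head.py's existing losses import:

```python
head = _binary_logistic_or_multi_class_head(
    n_classes=3,                              # > 2 selects the softmax head
    weight_column=None,
    label_vocabulary=['cat', 'dog', 'bird'],  # hypothetical labels
    loss_reduction=losses.Reduction.SUM)
```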
+ """ + if n_classes == 2: + head = _binary_logistic_head_with_sigmoid_cross_entropy_loss( + weight_column=weight_column, + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) + else: + head = _multi_class_head_with_softmax_cross_entropy_loss( + n_classes, weight_column=weight_column, + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) + return head diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 32a6339936..ecca3e8b0d 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -86,6 +86,98 @@ def _sigmoid(logits): return 1 / (1 + np.exp(-logits)) +class CreateEstimatorSpecTest(test.TestCase): + + class _HeadWithTPUSupport(head_lib._Head): + """Head that overrides _create_tpu_estimator_spec.""" + + def name(self): + return 'HeadWithTPUSupport' + + def logits_dimension(self): + return None + + def create_loss(self, features, mode, logits, labels): + return None + + def _create_tpu_estimator_spec(self, features, mode, logits, labels=None, + optimizer=None, train_op_fn=None, + regularization_losses=None): + return model_fn._TPUEstimatorSpec( + mode=model_fn.ModeKeys.EVAL, + loss=constant_op.constant(0.0, dtype=dtypes.float32)) + + class _HeadWithOutTPUSupport(head_lib._Head): + """Head that overrides create_estimator_spec.""" + + def name(self): + return 'HeadWithOutTPUSupport' + + def logits_dimension(self): + return None + + def create_loss(self, features, mode, logits, labels): + return None + + def create_estimator_spec(self, features, mode, logits, labels=None, + optimizer=None, train_op_fn=None, + regularization_losses=None): + return model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.EVAL, + loss=constant_op.constant(0.0, dtype=dtypes.float32)) + + class _InvalidHead(head_lib._Head): + """Head that overrides neither estimator_spec functions.""" + + def name(self): + return 'InvalidHead' + + def logits_dimension(self): + return None + + def create_loss(self, features, mode, logits, labels): + return None + + def test_head_override_tpu_estimator_spec(self): + """Test for `_Head` that overrides _create_tpu_estimator_spec.""" + head = self._HeadWithTPUSupport() + + tpu_spec = head._create_tpu_estimator_spec( + features=None, mode=None, logits=None) + self.assertTrue(isinstance(tpu_spec, model_fn._TPUEstimatorSpec)) + est_spec = head.create_estimator_spec( + features=None, mode=None, logits=None) + self.assertTrue(isinstance(est_spec, model_fn.EstimatorSpec)) + + def test_head_override_estimator_spec(self): + """Test for `_Head` that overrides create_estimator_spec.""" + head = self._HeadWithOutTPUSupport() + + with self.assertRaisesRegexp( + NotImplementedError, + 'TPUEstimatorSpec not available for this model head.'): + _ = head._create_tpu_estimator_spec( + features=None, mode=None, logits=None) + est_spec = head.create_estimator_spec( + features=None, mode=None, logits=None) + self.assertTrue(isinstance(est_spec, model_fn.EstimatorSpec)) + + def test_invalid_head_class(self): + head = self._InvalidHead() + + with self.assertRaisesRegexp( + NotImplementedError, + 'TPUEstimatorSpec not available for this model head.'): + _ = head._create_tpu_estimator_spec( + features=None, mode=None, logits=None) + with self.assertRaisesRegexp( + NotImplementedError, + r'Subclasses of _Head must implement `create_estimator_spec\(\)` or ' + r'_create_tpu_estimator_spec\(\).'): + _ = head.create_estimator_spec( + features=None, mode=None, 
logits=None) + + class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): def setUp(self): diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index cc8023a5e7..64457eb1ff 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -37,9 +37,8 @@ from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config from tensorflow.python.estimator import util -from tensorflow.python.estimator.export.export import build_all_signature_defs -from tensorflow.python.estimator.export.export import get_temp_export_dir -from tensorflow.python.estimator.export.export import get_timestamped_export_dir +from tensorflow.python.estimator.export import export as export_helpers +from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops @@ -51,7 +50,6 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import constants -from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import device_setter @@ -609,73 +607,283 @@ class Estimator(object): are provided, or no checkpoint can be found. """ # pylint: enable=line-too-long + return self._export_saved_model_for_mode( + export_dir_base, + serving_input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=model_fn_lib.ModeKeys.PREDICT) + + def _export_all_saved_models( + self, export_dir_base, input_receiver_fn_map, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + # pylint: disable=line-too-long + """Exports requested train/eval/predict graphs as separate SavedModels. + + This is a wrapper around export_saved_model_for_mode that accepts + multiple modes simultaneously and creates directories for each under + export_dir_base. See `Estimator.export_saved_model_for_mode` for + further details as to how the export works for each mode. + + See tf.contrib.estimator.export_all_saved_models for the currently + exposed version of this function. + + Args: + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn + mappings, where the input_receiver_fn is a function that takes no + argument and returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. 
For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + + Returns: + A dict of tf.estimator.ModeKeys value to string path for each exported + directory. + + Raises: + ValueError: if any input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + # TODO(b/65561022): Consider allowing multiple input_receiver_fns per mode. + exported = {} + for mode, input_receiver_fn in input_receiver_fn_map.items(): + export_mode_dir = os.path.join( + compat.as_bytes(export_dir_base), + compat.as_bytes(mode)) + gfile.MakeDirs(export_mode_dir) + + exported_path = self._export_saved_model_for_mode( + export_mode_dir, + input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=mode) + + exported[mode] = exported_path + + return exported + + def _export_saved_model_for_mode( + self, export_dir_base, input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Exports a single train/eval/predict graph as a SavedModel. + + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + + See tf.contrib.estimator.export_saved_model_for_mode for the currently + exposed version of this function. + + This method takes an input_receiver_fn and mode. For the mode passed in, + this method builds a new graph by calling the input_receiver_fn to obtain + feature and label `Tensor`s. Next, this method calls the `Estimator`'s + model_fn in the passed mode to generate the model graph based on + those features and labels, and restores the given checkpoint + (or, lacking that, the most recent checkpoint) into the graph. + Finally, it creates a timestamped export directory below the + export_dir_base, and writes a `SavedModel` into it containing + the `MetaGraphDef` for the given mode and its associated signatures. + + For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` + for each element of the export_outputs dict returned from the model_fn, + named using the same keys. One of these keys is always + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + signature will be served when a serving request does not specify one. + For each signature, the outputs are provided by the corresponding + `ExportOutput`s, and the inputs are always the input receivers provided by + the serving_input_receiver_fn. + + For training and evaluation, the train_op is stored in an extra collection, + and loss, metrics, and predictions are included in a SignatureDef for the + mode in question. + + Extra assets may be written into the SavedModel via the assets_extra + argument. This should be a dict, where each key gives a destination path + (including the filename) relative to the assets.extra directory. The + corresponding value gives the full path of the source file to be copied. + For example, the simple case of copying a single file without renaming it + is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + + Args: + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. 
+ input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating which mode will be exported. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long with context.graph_mode(): - if serving_input_receiver_fn is None: - raise ValueError('serving_input_receiver_fn must be defined.') + if not input_receiver_fn: + raise ValueError('An input_receiver_fn must be defined.') - with ops.Graph().as_default() as g: - self._create_and_assert_global_step(g) - random_seed.set_random_seed(self._config.tf_random_seed) - serving_input_receiver = serving_input_receiver_fn() + if not checkpoint_path: + # Locate the latest checkpoint + checkpoint_path = saver.latest_checkpoint(self._model_dir) + if not checkpoint_path: + raise ValueError("Couldn't find trained model at %s." % self._model_dir) - # Call the model_fn and collect the export_outputs. - estimator_spec = self._call_model_fn( - features=serving_input_receiver.features, - labels=None, - mode=model_fn_lib.ModeKeys.PREDICT, - config=self.config) - - # Build the SignatureDefs from receivers and all outputs - signature_def_map = build_all_signature_defs( - serving_input_receiver.receiver_tensors, - estimator_spec.export_outputs, - serving_input_receiver.receiver_tensors_alternatives) - - if not checkpoint_path: - # Locate the latest checkpoint - checkpoint_path = saver.latest_checkpoint(self._model_dir) - if not checkpoint_path: - raise ValueError( - "Couldn't find trained model at %s."
% self._model_dir) - - export_dir = get_timestamped_export_dir(export_dir_base) - temp_export_dir = get_temp_export_dir(export_dir) - - # TODO(soergel): Consider whether MonitoredSession makes sense here - with tf_session.Session(config=self._session_config) as session: - - saver_for_restore = estimator_spec.scaffold.saver or saver.Saver( - sharded=True) - saver_for_restore.restore(session, checkpoint_path) - - local_init_op = ( - estimator_spec.scaffold.local_init_op or - monitored_session.Scaffold.default_local_init_op()) - - # Perform the export - builder = saved_model_builder.SavedModelBuilder(temp_export_dir) - builder.add_meta_graph_and_variables( - session, [tag_constants.SERVING], - signature_def_map=signature_def_map, - assets_collection=ops.get_collection( - ops.GraphKeys.ASSET_FILEPATHS), - legacy_init_op=local_init_op, - strip_default_attrs=strip_default_attrs) - builder.save(as_text) - - # Add the extra assets - if assets_extra: - assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir), - compat.as_bytes('assets.extra')) - for dest_relative, source in assets_extra.items(): - dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), - compat.as_bytes(dest_relative)) - dest_path = os.path.dirname(dest_absolute) - gfile.MakeDirs(dest_path) - gfile.Copy(source, dest_absolute) - - gfile.Rename(temp_export_dir, export_dir) - return export_dir + export_dir = export_helpers.get_timestamped_export_dir(export_dir_base) + temp_export_dir = export_helpers.get_temp_export_dir(export_dir) + + builder = saved_model_builder.SavedModelBuilder(temp_export_dir) + + self._add_meta_graph_and_variables_for_mode( + builder, input_receiver_fn, checkpoint_path, + strip_default_attrs, mode) + + builder.save(as_text) + + # Add the extra assets + if assets_extra: + assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir), + compat.as_bytes('assets.extra')) + for dest_relative, source in assets_extra.items(): + dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), + compat.as_bytes(dest_relative)) + dest_path = os.path.dirname(dest_absolute) + gfile.MakeDirs(dest_path) + gfile.Copy(source, dest_absolute) + + gfile.Rename(temp_export_dir, export_dir) + return export_dir + + def _add_meta_graph_and_variables_for_mode( + self, builder, input_receiver_fn, checkpoint_path, strip_default_attrs, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Loads variables and adds them along with a MetaGraphDef for saving. + + Args: + builder: instance of SavedModelBuilder that will be used for saving. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating which mode will be exported. + """ + # pylint: enable=line-too-long + with ops.Graph().as_default() as g: + self._create_and_assert_global_step(g) + random_seed.set_random_seed(self._config.tf_random_seed) + + input_receiver = input_receiver_fn() + + # Call the model_fn and collect the export_outputs. 
+ estimator_spec = self._call_model_fn( + features=input_receiver.features, + labels=getattr(input_receiver, 'labels', None), + mode=mode, + config=self.config) + + export_outputs = self._get_export_outputs_for_spec(estimator_spec) + + # Build the SignatureDefs from receivers and all outputs + signature_def_map = export_helpers.build_all_signature_defs( + input_receiver.receiver_tensors, + export_outputs, + getattr(input_receiver, 'receiver_tensors_alternatives', None), + serving_only=(mode == model_fn_lib.ModeKeys.PREDICT)) + + with tf_session.Session(config=self._session_config) as session: + + export_tags = model_fn_lib.EXPORT_TAG_MAP[mode] + + local_init_op = ( + estimator_spec.scaffold.local_init_op or + monitored_session.Scaffold.default_local_init_op()) + + saver_for_restore = estimator_spec.scaffold.saver or saver.Saver( + sharded=True) + saver_for_restore.restore(session, checkpoint_path) + + # We add the train op explicitly for now, so that we don't have to + # change the Builder public interface. Note that this is a no-op + # for prediction, where train_op is None. + builder._add_train_op(estimator_spec.train_op) # pylint: disable=protected-access + + builder.add_meta_graph_and_variables( + session, + tags=export_tags, + signature_def_map=signature_def_map, + assets_collection=ops.get_collection( + ops.GraphKeys.ASSET_FILEPATHS), + strip_default_attrs=strip_default_attrs, + legacy_init_op=local_init_op) + + def _get_export_outputs_for_spec(self, estimator_spec): + """Given an EstimatorSpec, determine what our export outputs should be. + + EstimatorSpecs contain export_outputs that are used for serving, but for + training and eval graphs, we must wrap the tensors of interest in + appropriate ExportOutput objects. + + Args: + estimator_spec: EstimatorSpec object that will be exported. + + Returns: + a dict mapping export_output_name to ExportOutput object. 
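(A behavior sketch for this helper; the implementation follows in the next hunk. `est` stands in for an Estimator instance, and `loss`, `predictions`, and `metrics` are assumed to come from a supervised model_fn.)

spec = model_fn_lib.EstimatorSpec(
    mode=model_fn_lib.ModeKeys.EVAL,
    loss=loss,
    predictions=predictions,
    eval_metric_ops=metrics)
outputs = est._get_export_outputs_for_spec(spec)
# For TRAIN/EVAL the spec's loss, predictions, and metrics are bundled into
# a single TrainOutput/EvalOutput keyed by the mode name ('eval' here);
# for PREDICT the spec's own export_outputs dict is returned unchanged.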
+ + Raises: + ValueError: if an appropriate ExportOutput cannot be found for the + passed EstimatorSpec.mode + """ + mode = estimator_spec.mode + if mode == model_fn_lib.ModeKeys.PREDICT: + outputs = estimator_spec.export_outputs + else: + if mode == model_fn_lib.ModeKeys.TRAIN: + output_class = export_output.TrainOutput + elif mode == model_fn_lib.ModeKeys.EVAL: + output_class = export_output.EvalOutput + else: + raise ValueError( + 'Export output type not found for mode: {}'.format(mode)) + + export_out = output_class( + loss=estimator_spec.loss, + predictions=estimator_spec.predictions, + metrics=estimator_spec.eval_metric_ops) + outputs = {mode: export_out} + + return outputs def _get_features_from_input_fn(self, input_fn, mode): """Extracts the `features` from return values of `input_fn`.""" @@ -1544,3 +1752,5 @@ def _get_default_warm_start_settings(warm_start_from): else: raise ValueError('warm_start_from must be a string or a WarmStartSettings, ' 'instead got {}'.format(type(warm_start_from))) + + diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 76b45b7f57..02088e5134 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -1865,6 +1865,41 @@ def _model_fn_for_export_tests(features, labels, mode): 'test': export_output.ClassificationOutput(scores, classes)}) +def _x_y_input_fn(): + return ({'x': constant_op.constant([[1], [1]]), + 'y': constant_op.constant([[2], [2]])}, + constant_op.constant([[1], [1]])) + + +def _model_fn_with_x_y(features, labels, mode): + _ = labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(36., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + else: + prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else '' + + multiplied = math_ops.multiply( + features['x'], features['y'], name='{}multiplied'.format(prefix)) + metrics = {'mean': metrics_lib.mean(features['x'] - features['y'], + name='{}mean'.format(prefix))} + variables.Variable(1., name='later_var') + variables.Variable(3., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=multiplied, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + def _model_fn_with_saveables_for_export_tests(features, labels, mode): _, _ = features, labels table = saver_test_utils.CheckpointedOp(name='v2') @@ -1881,21 +1916,41 @@ def _model_fn_with_saveables_for_export_tests(features, labels, mode): 'test': export_output.PredictOutput({'prediction': prediction})}) +def _get_serving_input_receiver_fn(): + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + return export.build_parsing_serving_input_receiver_fn(feature_spec) + + +def _get_supervised_input_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + 'y': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_y') + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='truth') + + return export.build_raw_supervised_input_receiver_fn(feature_spec, label_spec) + + _VOCAB_FILE_CONTENT = 
'emerson\nlake\npalmer\n' _EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n' class EstimatorExportTest(test.TestCase): - def test_export_savedmodel_proto_roundtrip(self): - tmpdir = tempfile.mkdtemp() - est = estimator.Estimator(model_fn=_model_fn_for_export_tests) - est.train(input_fn=dummy_input_fn, steps=1) + def test_export_savedmodel_proto_roundtrip_raw_receiver(self): feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( feature_spec) + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=dummy_input_fn, steps=1) + # Perform the export. export_dir_base = os.path.join( compat.as_bytes(tmpdir), compat.as_bytes('export')) @@ -1904,6 +1959,266 @@ class EstimatorExportTest(test.TestCase): # Check that all the files are in the right places. self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertTrue('weight' in graph_ops) + + def test_export_saved_model_train(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN) + + def test_export_saved_model_eval(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL) + + def test_export_saved_model_predict(self): + self._test_export_saved_model_for_mode( + _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT) + + def _test_export_saved_model_for_mode(self, input_receiver_fn, mode): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = est._export_saved_model_for_mode( + export_dir_base, input_receiver_fn, mode=mode) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + tag_set = model_fn_lib.EXPORT_TAG_MAP[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('name_collision_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_receiver_map(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. 
+ export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertFalse('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_train_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertTrue('mean/update_op' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_eval_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertTrue('eval_mean/value' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_no_serving(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 2) + # Restore, to validate that the export was well-formed. 
+ export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + # TODO(karmel): is this the desired behavior when names are shared? + self.assertTrue('feature_x_1' in graph_ops) + self.assertTrue('feature_y_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_three_defs(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items(): + export_dir = export_dirs[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('global_step/Assign' in graph_ops) + self.assertTrue('global_step/Initializer/zeros' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_all_vars(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_name_collision(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertEqual(3, collection_vars[-1].eval()) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # This is a non-obvious detail: when we load the estimator spec + # for predict, name_collision gets set to 36. However, we then restore + # from checkpoint, which should overwrite that var and make it the 3 + # from training. In practice, this would not be a good way to write + # a model_fn, but leaving this check in for now to ensure consistency + # with what would happen given our current order of spec, then + # checkpoint. + self.assertEqual(3, collection_vars[-1].eval()) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def _test_export_all_saved_models(self, input_receiver_fn_map): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_x_y) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dirs = est._export_all_saved_models( + export_dir_base, input_receiver_fn_map) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + + for _, export_dir in export_dirs.items(): + self._validate_exported_files(export_dir) + + return export_dirs, tmpdir + + def _validate_exported_files(self, export_dir): self.assertTrue(gfile.Exists(export_dir)) self.assertTrue(gfile.Exists(os.path.join( compat.as_bytes(export_dir), @@ -1918,18 +2233,6 @@ class EstimatorExportTest(test.TestCase): compat.as_bytes(export_dir), compat.as_bytes('variables/variables.data-00000-of-00001')))) - # Restore, to validate that the export was well-formed. - with ops.Graph().as_default() as graph: - with session.Session(graph=graph) as sess: - loader.load(sess, [tag_constants.SERVING], export_dir) - graph_ops = [x.name for x in graph.get_operations()] - self.assertTrue('input_example_tensor' in graph_ops) - self.assertTrue('ParseExample/ParseExample' in graph_ops) - self.assertTrue('weight' in graph_ops) - - # Clean up. 
- gfile.DeleteRecursively(tmpdir) - def test_export_savedmodel_with_saveables_proto_roundtrip(self): tmpdir = tempfile.mkdtemp() est = estimator.Estimator( @@ -2485,5 +2788,6 @@ class EstimatorIntegrationTest(test.TestCase): serving_input_receiver_fn) self.assertTrue(gfile.Exists(export_dir)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 41c1f5a2e2..9aafb56679 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -40,6 +40,60 @@ from tensorflow.python.util.tf_export import tf_export _SINGLE_FEATURE_DEFAULT_NAME = 'feature' _SINGLE_RECEIVER_DEFAULT_NAME = 'input' +_SINGLE_LABEL_DEFAULT_NAME = 'label' + + +def _wrap_and_check_receiver_tensors(receiver_tensors): + """Ensure that receiver_tensors is a dict of str to Tensor mappings. + + Args: + receiver_tensors: dict of str to Tensors, or a single Tensor. + + Returns: + dict of str to Tensors; this is the original dict if one was passed, or + the original tensor wrapped in a dictionary. + + Raises: + ValueError: if receiver_tensors is None, or has non-string keys, + or non-Tensor values + """ + if receiver_tensors is None: + raise ValueError('receiver_tensors must be defined.') + if not isinstance(receiver_tensors, dict): + receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} + for name, tensor in receiver_tensors.items(): + _check_tensor_key(name, error_label='receiver_tensors') + _check_tensor(tensor, name, error_label='receiver_tensor') + return receiver_tensors + + +def _check_tensor(tensor, name, error_label='feature'): + """Check that passed `tensor` is a Tensor or SparseTensor.""" + if not (isinstance(tensor, ops.Tensor) + or isinstance(tensor, sparse_tensor.SparseTensor)): + fmt_name = ' {}'.format(name) if name else '' + value_error = ValueError( + '{}{} must be a Tensor or SparseTensor.'.format(error_label, fmt_name)) + # NOTE(ericmc): This if-else block is a specific carve-out for + # LabeledTensor, which has a `.tensor` attribute and which is + # convertible to tf.Tensor via ops.convert_to_tensor. + # Allowing all types convertible to tf.Tensor is considered by soergel@ + # to be too permissive. + # TODO(soergel): accept any type convertible to Tensor, + # as in cl/193238295 snapshot #6. + if hasattr(tensor, 'tensor'): + try: + ops.convert_to_tensor(tensor) + except TypeError: + raise value_error + else: + raise value_error + + +def _check_tensor_key(name, error_label='feature'): + if not isinstance(name, six.string_types): + raise ValueError( + '{} keys must be strings: {}.'.format(error_label, name)) @tf_export('estimator.export.ServingInputReceiver') @@ -51,16 +105,18 @@ class ServingInputReceiver(collections.namedtuple( The expected return values are: features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or `SparseTensor`, specifying the features to be passed to the model. - receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying - input nodes where this receiver expects to be fed by default. Typically, - this is a single placeholder expecting serialized `tf.Example` protos. + receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` + or `SparseTensor`, specifying input nodes where this receiver expects to + be fed by default. Typically, this is a single placeholder expecting + serialized `tf.Example` protos. 
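(Illustrative behavior of the new private checking helpers above; a sketch assuming the module-level names just defined are in scope.)

_check_tensor(constant_op.constant([1]), 'x')  # passes: dense Tensor
_check_tensor(sparse_tensor.SparseTensor(
    indices=[[0, 0]], values=[1], dense_shape=[1, 1]), 'x')  # passes
_check_tensor([1], 'x')  # ValueError: feature x must be a Tensor or SparseTensor.
# A bare tensor passed as receiver_tensors is wrapped under the default key
# 'input' before the same checks run:
ph = array_ops.placeholder(dtypes.string)
_wrap_and_check_receiver_tensors(ph)  # -> {'input': ph}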
receiver_tensors_alternatives: a dict of string to additional - groups of receiver tensors, each of which may be a `Tensor` or a dict of - string to `Tensor`. These named receiver tensor alternatives generate - additional serving signatures, which may be used to feed inputs at - different points within the input receiver subgraph. A typical usage is - to allow feeding raw feature `Tensor`s *downstream* of the - tf.parse_example() op. Defaults to None. + groups of receiver tensors, each of which may be a `Tensor`, + `SparseTensor`, or dict of string to `Tensor` or `SparseTensor`. + These named receiver tensor alternatives generate additional serving + signatures, which may be used to feed inputs at different points within + the input receiver subgraph. A typical usage is to allow feeding raw + feature `Tensor`s *downstream* of the tf.parse_example() op. + Defaults to None. """ def __new__(cls, features, receiver_tensors, @@ -70,36 +126,10 @@ class ServingInputReceiver(collections.namedtuple( if not isinstance(features, dict): features = {_SINGLE_FEATURE_DEFAULT_NAME: features} for name, tensor in features.items(): - if not isinstance(name, six.string_types): - raise ValueError('feature keys must be strings: {}.'.format(name)) - if not (isinstance(tensor, ops.Tensor) - or isinstance(tensor, sparse_tensor.SparseTensor)): - value_error = ValueError( - 'feature {} must be a Tensor or SparseTensor.'.format(name)) - # NOTE(ericmc): This if-else block is a specific carve-out for - # LabeledTensor, which has a `.tensor` attribute and which is - # convertible to tf.Tensor via ops.convert_to_tensor. - # Allowing all types convertible to tf.Tensor is considered by soergel@ - # to be too permissive. - if hasattr(tensor, 'tensor'): - try: - ops.convert_to_tensor(tensor) - except TypeError: - raise value_error - else: - raise value_error - - if receiver_tensors is None: - raise ValueError('receiver_tensors must be defined.') - if not isinstance(receiver_tensors, dict): - receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} - for name, tensor in receiver_tensors.items(): - if not isinstance(name, six.string_types): - raise ValueError( - 'receiver_tensors keys must be strings: {}.'.format(name)) - if not isinstance(tensor, ops.Tensor): - raise ValueError( - 'receiver_tensor {} must be a Tensor.'.format(name)) + _check_tensor_key(name) + _check_tensor(tensor, name) + + receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors) if receiver_tensors_alternatives is not None: if not isinstance(receiver_tensors_alternatives, dict): @@ -115,14 +145,9 @@ class ServingInputReceiver(collections.namedtuple( receiver_tensors_alternatives[alternative_name] = ( receiver_tensors_alt) for name, tensor in receiver_tensors_alt.items(): - if not isinstance(name, six.string_types): - raise ValueError( - 'receiver_tensors keys must be strings: {}.'.format(name)) - if not (isinstance(tensor, ops.Tensor) - or isinstance(tensor, sparse_tensor.SparseTensor)): - raise ValueError( - 'receiver_tensor {} must be a Tensor or SparseTensor.'.format( - name)) + _check_tensor_key(name, error_label='receiver_tensors_alternative') + _check_tensor( + tensor, name, error_label='receiver_tensors_alternative') return super(ServingInputReceiver, cls).__new__( cls, @@ -155,25 +180,25 @@ class TensorServingInputReceiver(collections.namedtuple( The expected return values are: features: A single `Tensor` or `SparseTensor`, representing the feature to be passed to the model.
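(An aside on the receiver_tensors contract widened above: both receiver classes now accept a SparseTensor receiver, where previously only dense Tensors passed validation. A sketch:)

serialized = sparse_tensor.SparseTensor(
    indices=[[0, 0]], values=['raw_bytes'], dense_shape=[1, 1])
receiver = ServingInputReceiver(
    features={'x': constant_op.constant([0])},
    receiver_tensors=serialized)
# The bare SparseTensor is wrapped under the default key:
# receiver.receiver_tensors == {'input': serialized}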
- receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying - input nodes where this receiver expects to be fed by default. Typically, - this is a single placeholder expecting serialized `tf.Example` protos. + receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` + or `SparseTensor`, specifying input nodes where this receiver expects to + be fed by default. Typically, this is a single placeholder expecting + serialized `tf.Example` protos. receiver_tensors_alternatives: a dict of string to additional - groups of receiver tensors, each of which may be a `Tensor` or a dict of - string to `Tensor`. These named receiver tensor alternatives generate - additional serving signatures, which may be used to feed inputs at - different points within the input receiver subgraph. A typical usage is - to allow feeding raw feature `Tensor`s *downstream* of the - tf.parse_example() op. Defaults to None. + groups of receiver tensors, each of which may be a `Tensor`, + `SparseTensor`, or dict of string to `Tensor` or `SparseTensor`. + These named receiver tensor alternatives generate additional serving + signatures, which may be used to feed inputs at different points within + the input receiver subgraph. A typical usage is to allow feeding raw + feature `Tensor`s *downstream* of the tf.parse_example() op. + Defaults to None. """ def __new__(cls, features, receiver_tensors, receiver_tensors_alternatives=None): if features is None: raise ValueError('features must be defined.') - if not (isinstance(features, ops.Tensor) - or isinstance(features, sparse_tensor.SparseTensor)): - raise ValueError('feature must be a Tensor or SparseTensor.') + _check_tensor(features, None) receiver = ServingInputReceiver( features=features, @@ -187,6 +212,49 @@ class TensorServingInputReceiver(collections.namedtuple( receiver_tensors_alternatives=receiver.receiver_tensors_alternatives) + +class SupervisedInputReceiver(collections.namedtuple( + 'SupervisedInputReceiver', + ['features', 'labels', 'receiver_tensors'])): + """A return type for a training_input_receiver_fn or eval_input_receiver_fn. + + This differs from a ServingInputReceiver in that (1) this receiver expects + a set of labels to be passed in with features, and (2) this receiver does + not support receiver_tensors_alternatives, which are primarily used for + serving. + + The expected return values are: + features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the features to be passed to the model. + labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the labels to be passed to the model. + receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` + or `SparseTensor`, specifying input nodes where this receiver expects to + be fed by default. Typically, this is a single placeholder expecting + serialized `tf.Example` protos. + + """ + + def __new__(cls, features, labels, receiver_tensors): + # Both features and labels can be dicts or raw tensors.
+ for input_vals, error_label in ((features, 'feature'), (labels, 'label')): + if input_vals is None: + raise ValueError('{}s must be defined.'.format(error_label)) + if isinstance(input_vals, dict): + for name, tensor in input_vals.items(): + _check_tensor_key(name, error_label=error_label) + _check_tensor(tensor, name, error_label=error_label) + else: + _check_tensor(input_vals, None, error_label=error_label) + + receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors) + + return super(SupervisedInputReceiver, cls).__new__( + cls, + features=features, + labels=labels, + receiver_tensors=receiver_tensors) + + @tf_export('estimator.export.build_parsing_serving_input_receiver_fn') def build_parsing_serving_input_receiver_fn(feature_spec, default_batch_size=None): @@ -216,6 +284,23 @@ def build_parsing_serving_input_receiver_fn(feature_spec, return serving_input_receiver_fn + +def _placeholder_from_tensor(t, default_batch_size=None): + shape_list = t.get_shape().as_list() + shape_list[0] = default_batch_size + shape = tensor_shape.TensorShape(shape_list) + + # Reuse the feature tensor's op name (t.op.name) for the placeholder, + # excluding the index from the tensor's name (t.name): + # t.name = "%s:%d" % (t.op.name, t._value_index) + return array_ops.placeholder(dtype=t.dtype, shape=shape, name=t.op.name) + + +def _placeholders_from_receiver_tensors_dict( + input_vals, default_batch_size=None): + return {name: _placeholder_from_tensor(t, default_batch_size) + for name, t in input_vals.items()} + + @tf_export('estimator.export.build_raw_serving_input_receiver_fn') def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """Build a serving_input_receiver_fn expecting feature Tensors. @@ -233,17 +318,9 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """ def serving_input_receiver_fn(): """A serving_input_receiver_fn that expects features to be fed directly.""" - receiver_tensors = {} - for name, t in features.items(): - shape_list = t.get_shape().as_list() - shape_list[0] = default_batch_size - shape = tensor_shape.TensorShape(shape_list) - - # Reuse the feature tensor's op name (t.op.name) for the placeholder, - # excluding the index from the tensor's name (t.name): - # t.name = "%s:%d" % (t.op.name, t._value_index) - receiver_tensors[name] = array_ops.placeholder( - dtype=t.dtype, shape=shape, name=t.op.name) + receiver_tensors = _placeholders_from_receiver_tensors_dict( + features, default_batch_size) + # TODO(b/34885899): remove the unnecessary copy # The features provided are simply the placeholders, but we defensively copy # the dict because it may be mutated. @@ -252,13 +329,100 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None): return serving_input_receiver_fn + +def build_raw_supervised_input_receiver_fn( + features, labels, default_batch_size=None): + """Build a supervised_input_receiver_fn for raw features and labels. + + This function wraps tensor placeholders in a supervised_input_receiver_fn + with the expectation that the features and labels appear precisely as + the model_fn expects them. Features and labels can therefore be dicts of + tensors, or raw tensors. + + Args: + features: a dict of string to `Tensor`, or a single raw `Tensor`. + labels: a dict of string to `Tensor`, or a single raw `Tensor`. + default_batch_size: the number of query examples expected per batch. + Leave unset for variable batch size (recommended). + + Returns: + A supervised_input_receiver_fn.
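(A usage sketch for this builder, mirroring test_build_raw_supervised_input_receiver_fn later in this diff:)

features = {'feature_1': constant_op.constant(['hello']),
            'feature_2': constant_op.constant([42])}
labels = constant_op.constant([5])
receiver_fn = build_raw_supervised_input_receiver_fn(features, labels)
rec = receiver_fn()
# rec.features holds placeholder copies of the feature tensors; the raw
# label is wrapped under the default key, so rec.receiver_tensors ends up
# with the keys {'feature_1', 'feature_2', 'label'}.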
+ + Raises: + ValueError: if features and labels have overlapping keys. + """ + # Check for overlapping keys before beginning. + try: + feat_keys = features.keys() + except AttributeError: + feat_keys = [_SINGLE_RECEIVER_DEFAULT_NAME] + try: + label_keys = labels.keys() + except AttributeError: + label_keys = [_SINGLE_LABEL_DEFAULT_NAME] + + overlap_keys = set(feat_keys) & set(label_keys) + if overlap_keys: + raise ValueError('Features and labels must have distinct keys. ' + 'Found overlapping keys: {}'.format(overlap_keys)) + + def supervised_input_receiver_fn(): + """A receiver_fn that expects pass-through features and labels.""" + if not isinstance(features, dict): + features_cp = _placeholder_from_tensor(features, default_batch_size) + receiver_features = {_SINGLE_RECEIVER_DEFAULT_NAME: features_cp} + else: + receiver_features = _placeholders_from_receiver_tensors_dict( + features, default_batch_size) + features_cp = receiver_features + + if not isinstance(labels, dict): + labels_cp = _placeholder_from_tensor(labels, default_batch_size) + receiver_labels = {_SINGLE_LABEL_DEFAULT_NAME: labels_cp} + else: + receiver_labels = _placeholders_from_receiver_tensors_dict( + labels, default_batch_size) + labels_cp = receiver_labels + + receiver_tensors = dict(receiver_features) + receiver_tensors.update(receiver_labels) + return SupervisedInputReceiver(features_cp, labels_cp, receiver_tensors) + + return supervised_input_receiver_fn + + ### Below utilities are specific to SavedModel exports. def build_all_signature_defs(receiver_tensors, export_outputs, - receiver_tensors_alternatives=None): - """Build `SignatureDef`s for all export outputs.""" + receiver_tensors_alternatives=None, + serving_only=True): + """Build `SignatureDef`s for all export outputs. + + Args: + receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying + input nodes where this receiver expects to be fed by default. Typically, + this is a single placeholder expecting serialized `tf.Example` protos. + export_outputs: a dict of ExportOutput instances, each of which has + an as_signature_def instance method that will be called to retrieve + the signature_def for all export output tensors. + receiver_tensors_alternatives: a dict of string to additional + groups of receiver tensors, each of which may be a `Tensor` or a dict of + string to `Tensor`. These named receiver tensor alternatives generate + additional serving signatures, which may be used to feed inputs at + different points within the input receiver subgraph. A typical usage is + to allow feeding raw feature `Tensor`s *downstream* of the + tf.parse_example() op. Defaults to None. + serving_only: boolean; if true, resulting signature defs will only include + valid serving signatures. If false, all requested signatures will be + returned. + + Returns: + a dict of `SignatureDef`s representing all passed args. + + Raises: + ValueError: if export_outputs is not a dict. + """ if not isinstance(receiver_tensors, dict): receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} if export_outputs is None or not isinstance(export_outputs, dict): @@ -293,17 +457,24 @@ def build_all_signature_defs(receiver_tensors, _log_signature_report(signature_def_map, excluded_signatures) # The above calls to export_output.as_signature_def should return only - # valid signatures; if there is a validity problem, they raise ValueError, - # which we ignore above. Consequently the call to is_valid_signature here - # should not remove anything else; it's just an extra sanity check.
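(A behavior sketch for the new serving_only flag, following test_build_all_signature_defs_serving_only below; `receiver_tensors` and `output_tensor` are assumed placeholders/Tensors:)

export_outputs = {
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
        export_output.PredictOutput(outputs=output_tensor),
    'train': export_output.TrainOutput(loss=output_tensor),
}
sigs = build_all_signature_defs(receiver_tensors, export_outputs)
# -> {'serving_default': ...}; the train signature is filtered out.
sigs = build_all_signature_defs(receiver_tensors, export_outputs,
                                serving_only=False)
# -> {'serving_default': ..., 'train': ...}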
- return {k: v for k, v in signature_def_map.items() - if signature_def_utils.is_valid_signature(v)} + # valid signatures; if there is a validity problem, they raise a ValueError, + # in which case we exclude that signature from signature_def_map above. + # The is_valid_signature check ensures that the signatures produced are + # valid for serving, and acts as an additional sanity check on the + # exported serving signatures. We skip this check for training and eval + # signatures, which are not intended for serving. + if serving_only: + signature_def_map = {k: v for k, v in signature_def_map.items() + if signature_def_utils.is_valid_signature(v)} + return signature_def_map _FRIENDLY_METHOD_NAMES = { signature_constants.CLASSIFY_METHOD_NAME: 'Classify', signature_constants.REGRESS_METHOD_NAME: 'Regress', signature_constants.PREDICT_METHOD_NAME: 'Predict', + signature_constants.SUPERVISED_TRAIN_METHOD_NAME: 'Train', + signature_constants.SUPERVISED_EVAL_METHOD_NAME: 'Eval', } diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 87b964be37..d387ea2940 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -38,6 +38,8 @@ class ExportOutput(object): __metaclass__ = abc.ABCMeta + _SEPARATOR_CHAR = '/' + @abc.abstractmethod def as_signature_def(self, receiver_tensors): """Generate a SignatureDef proto for inclusion in a MetaGraphDef. @@ -51,6 +53,52 @@ class ExportOutput(object): """ pass + def _check_output_key(self, key, error_label): + # For multi-head models, the key can be a tuple. + if isinstance(key, tuple): + key = self._SEPARATOR_CHAR.join(key) + + if not isinstance(key, six.string_types): + raise ValueError( + '{} output key must be a string; got {}.'.format(error_label, key)) + return key + + def _wrap_and_check_outputs( + self, outputs, single_output_default_name, error_label=None): + """Wraps raw tensors as dicts and checks type. + + Note that we create a new dict here so that we can overwrite the keys + if necessary. + + Args: + outputs: A `Tensor` or a dict of string to `Tensor`. + single_output_default_name: A string key for use in the output dict + if the provided `outputs` is a raw tensor. + error_label: descriptive string for use in error messages. If None, + single_output_default_name will be used. + + Returns: + A dict of tensors. + + Raises: + ValueError: if the outputs dict keys are not strings or tuples of strings + or the values are not Tensors. + """ + if not isinstance(outputs, dict): + outputs = {single_output_default_name: outputs} + + output_dict = {} + for key, value in outputs.items(): + error_name = error_label or single_output_default_name + key = self._check_output_key(key, error_name) + if not isinstance(value, ops.Tensor): + raise ValueError( + '{} output value must be a Tensor; got {}.'.format( + error_name, value)) + + output_dict[key] = value + return output_dict + @tf_export('estimator.export.ClassificationOutput') class ClassificationOutput(ExportOutput): @@ -154,9 +202,6 @@ class RegressionOutput(ExportOutput): return signature_def_utils.regression_signature_def(examples, self.value) -_SINGLE_OUTPUT_DEFAULT_NAME = 'output' - - @tf_export('estimator.export.PredictOutput') class PredictOutput(ExportOutput): """Represents the output of a generic prediction head.
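(A key-handling sketch for the new base-class helpers above, written as if from inside an ExportOutput subclass; `scores` and `loss` are assumed Tensors:)

self._wrap_and_check_outputs(scores, 'output', error_label='Prediction')
# -> {'output': scores}: a bare Tensor is wrapped under the default name.
self._wrap_and_check_outputs({('my', 'loss'): loss}, 'loss')
# -> {'my/loss': loss}: tuple keys (multi-head models) are joined with '/'.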
@@ -165,6 +210,7 @@ class PredictOutput(ExportOutput): Named outputs must be provided as a dict from string to `Tensor`, """ + _SINGLE_OUTPUT_DEFAULT_NAME = 'output' def __init__(self, outputs): """Constructor for PredictOutput. @@ -177,16 +223,9 @@ ValueError: if the outputs is not dict, or any of its keys are not strings, or any of its values are not `Tensor`s. """ - if not isinstance(outputs, dict): - outputs = {_SINGLE_OUTPUT_DEFAULT_NAME: outputs} - for key, value in outputs.items(): - if not isinstance(key, six.string_types): - raise ValueError( - 'Prediction output key must be a string; got {}.'.format(key)) - if not isinstance(value, ops.Tensor): - raise ValueError( - 'Prediction output value must be a Tensor; got {}.'.format(value)) - self._outputs = outputs + + self._outputs = self._wrap_and_check_outputs( + outputs, self._SINGLE_OUTPUT_DEFAULT_NAME, error_label='Prediction') @property def outputs(self): @@ -195,3 +234,161 @@ def as_signature_def(self, receiver_tensors): return signature_def_utils.predict_signature_def(receiver_tensors, self.outputs) + + +class _SupervisedOutput(ExportOutput): + """Represents the output of a supervised training or eval process.""" + __metaclass__ = abc.ABCMeta + + LOSS_NAME = 'loss' + PREDICTIONS_NAME = 'predictions' + METRICS_NAME = 'metrics' + + METRIC_VALUE_SUFFIX = 'value' + METRIC_UPDATE_SUFFIX = 'update_op' + + _loss = None + _predictions = None + _metrics = None + + def __init__(self, loss=None, predictions=None, metrics=None): + """Constructor for SupervisedOutput (i.e., Train or Eval output). + + Args: + loss: dict of Tensors or single Tensor representing calculated loss. + predictions: dict of Tensors or single Tensor representing model + predictions. + metrics: dict of (metric_value, update_op) tuples, or a single tuple. + metric_value must be a Tensor, and update_op must be a Tensor or Op. + + Raises: + ValueError: if any of the outputs' dict keys are not strings or tuples of + strings or the values are not Tensors (or Operations in the case of + update_op). + """ + + if loss is not None: + loss_dict = self._wrap_and_check_outputs(loss, self.LOSS_NAME) + self._loss = self._prefix_output_keys(loss_dict, self.LOSS_NAME) + if predictions is not None: + pred_dict = self._wrap_and_check_outputs( + predictions, self.PREDICTIONS_NAME) + self._predictions = self._prefix_output_keys( + pred_dict, self.PREDICTIONS_NAME) + if metrics is not None: + self._metrics = self._wrap_and_check_metrics(metrics) + + def _prefix_output_keys(self, output_dict, output_name): + """Prepend output_name to the output_dict keys if it is not already present. + + This produces predictable prefixes for the pre-determined outputs + of SupervisedOutput. + + Args: + output_dict: dict of string to Tensor, assumed valid. + output_name: prefix string to prepend to existing keys. + + Returns: + dict with updated keys and existing values. + """ + + new_outputs = {} + for key, val in output_dict.items(): + key = self._prefix_key(key, output_name) + new_outputs[key] = val + return new_outputs + + def _prefix_key(self, key, output_name): + if key.find(output_name) != 0: + key = output_name + self._SEPARATOR_CHAR + key + return key + + def _wrap_and_check_metrics(self, metrics): + """Handle the saving of metrics. + + Metrics is either a tuple of (value, update_op), or a dict of such tuples. + Here, we separate out the tuples and create a dict from names to tensors.
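(A naming sketch for the metric handling described above, matching SupervisedOutputTest below; `acc_value` and `acc_update_op` are assumed metric tensors:)

metrics = {'accuracy': (acc_value, acc_update_op)}
out = export_output.TrainOutput(loss=loss, metrics=metrics)
sorted(out.metrics)
# Each (value, update_op) pair is exported as two outputs:
# -> ['metrics/accuracy/update_op', 'metrics/accuracy/value']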
+ + Args: + metrics: dict of (metric_value, update_op) tuples, or a single tuple. + + Returns: + dict of output_names to tensors + + Raises: + ValueError: if the dict key is not a string, or the metric values or ops + are not tensors. + """ + if not isinstance(metrics, dict): + metrics = {self.METRICS_NAME: metrics} + + outputs = {} + for key, (metric_val, metric_op) in metrics.items(): + key = self._check_output_key(key, self.METRICS_NAME) + key = self._prefix_key(key, self.METRICS_NAME) + + val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX + op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX + if not isinstance(metric_val, ops.Tensor): + raise ValueError( + '{} output value must be a Tensor; got {}.'.format( + key, metric_val)) + if (not isinstance(metric_op, ops.Tensor) and + not isinstance(metric_op, ops.Operation)): + raise ValueError( + '{} update_op must be a Tensor or Operation; got {}.'.format( + key, metric_op)) + outputs[val_name] = metric_val + outputs[op_name] = metric_op + + return outputs + + @property + def loss(self): + return self._loss + + @property + def predictions(self): + return self._predictions + + @property + def metrics(self): + return self._metrics + + @abc.abstractmethod + def _get_signature_def_fn(self): + """Returns a function that produces a SignatureDef given desired outputs.""" + pass + + def as_signature_def(self, receiver_tensors): + signature_def_fn = self._get_signature_def_fn() + return signature_def_fn( + receiver_tensors, self.loss, self.predictions, self.metrics) + + +class TrainOutput(_SupervisedOutput): + """Represents the output of a supervised training process. + + This class generates the appropriate signature def for exporting + training output by type-checking and wrapping loss, predictions, and metrics + values. + """ + + def _get_signature_def_fn(self): + return signature_def_utils.supervised_train_signature_def + + +class EvalOutput(_SupervisedOutput): + """Represents the output of a supervised eval process. + + This class generates the appropriate signature def for exporting + eval output by type-checking and wrapping loss, predictions, and metrics + values. 
+ """ + + def _get_signature_def_fn(self): + return signature_def_utils.supervised_eval_signature_def + + + + diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py index 7090e53d80..b21ba91b0f 100644 --- a/tensorflow/python/estimator/export/export_output_test.py +++ b/tensorflow/python/estimator/export/export_output_test.py @@ -225,5 +225,115 @@ class ExportOutputTest(test.TestCase): }) +class MockSupervisedOutput(export_output_lib._SupervisedOutput): + """So that we can test the abstract class methods directly.""" + + def _get_signature_def_fn(self): + pass + + +class SupervisedOutputTest(test.TestCase): + + def test_supervised_outputs_valid(self): + """Tests that no errors are raised when provided outputs are valid.""" + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + metrics = {"metrics": (constant_op.constant([0]), + constant_op.constant([10])), + "metrics2": (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(outputter.loss["loss/my_loss"], loss["my_loss"]) + self.assertEqual( + outputter.predictions["predictions/output1"], predictions["output1"]) + self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0]) + self.assertEqual( + outputter.metrics["metrics2/update_op"], metrics["metrics2"][1]) + + # Single Tensor is OK too + outputter = MockSupervisedOutput( + loss["my_loss"], predictions["output1"], metrics["metrics"]) + self.assertEqual(outputter.loss, {"loss": loss["my_loss"]}) + self.assertEqual( + outputter.predictions, {"predictions": predictions["output1"]}) + self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0]) + + def test_supervised_outputs_none(self): + outputter = MockSupervisedOutput( + constant_op.constant([0]), None, None) + self.assertEqual(len(outputter.loss), 1) + self.assertEqual(outputter.predictions, None) + self.assertEqual(outputter.metrics, None) + + def test_supervised_outputs_invalid(self): + with self.assertRaisesRegexp(ValueError, "predictions output value must"): + MockSupervisedOutput(constant_op.constant([0]), [3], None) + with self.assertRaisesRegexp(ValueError, "loss output value must"): + MockSupervisedOutput("str", None, None) + with self.assertRaisesRegexp(ValueError, "metrics output value must"): + MockSupervisedOutput(None, None, (15.3, 4)) + with self.assertRaisesRegexp(ValueError, "loss output key must"): + MockSupervisedOutput({25: "Tensor"}, None, None) + + def test_supervised_outputs_tuples(self): + """Tests that no errors are raised when provided outputs are valid.""" + loss = {("my", "loss"): constant_op.constant([0])} + predictions = {(u"output1", "2"): constant_op.constant(["foo"])} + metrics = {("metrics", "twice"): (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(set(outputter.loss.keys()), set(["loss/my/loss"])) + self.assertEqual(set(outputter.predictions.keys()), + set(["predictions/output1/2"])) + self.assertEqual(set(outputter.metrics.keys()), + set(["metrics/twice/value", "metrics/twice/update_op"])) + + def test_supervised_outputs_no_prepend(self): + """Tests that no errors are raised when provided outputs are valid.""" + loss = {"loss": constant_op.constant([0])} + predictions = {u"predictions": constant_op.constant(["foo"])} + metrics = {u"metrics": 
(constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(set(outputter.loss.keys()), set(["loss"])) + self.assertEqual(set(outputter.predictions.keys()), set(["predictions"])) + self.assertEqual(set(outputter.metrics.keys()), + set(["metrics/value", "metrics/update_op"])) + + def test_train_signature_def(self): + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + metrics = {"metrics": (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = export_output_lib.TrainOutput(loss, predictions, metrics) + + receiver = {u"features": constant_op.constant(100, shape=(100, 2)), + "labels": constant_op.constant(100, shape=(100, 1))} + sig_def = outputter.as_signature_def(receiver) + + self.assertTrue("loss/my_loss" in sig_def.outputs) + self.assertTrue("metrics/value" in sig_def.outputs) + self.assertTrue("predictions/output1" in sig_def.outputs) + self.assertTrue("features" in sig_def.inputs) + + def test_eval_signature_def(self): + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + + outputter = export_output_lib.EvalOutput(loss, predictions, None) + + receiver = {u"features": constant_op.constant(100, shape=(100, 2)), + "labels": constant_op.constant(100, shape=(100, 1))} + sig_def = outputter.as_signature_def(receiver) + + self.assertTrue("loss/my_loss" in sig_def.outputs) + self.assertFalse("metrics/value" in sig_def.outputs) + self.assertTrue("predictions/output1" in sig_def.outputs) + self.assertTrue("features" in sig_def.inputs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index c203be7dac..0af587f2a8 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -54,7 +54,7 @@ ops.register_tensor_conversion_function(LabeledTensorMock, _convert_labeled_tensor_mock_to_tensor) -class ExportTest(test_util.TensorFlowTestCase): +class ServingInputReceiverTest(test_util.TensorFlowTestCase): def test_serving_input_receiver_constructor(self): """Tests that no errors are raised when input is expected.""" @@ -161,6 +161,165 @@ class ExportTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): _ = export.ServingInputReceiver(feature, receiver_tensor) + +class SupervisedInputReceiverTest(test_util.TensorFlowTestCase): + + def test_input_receiver_constructor(self): + """Tests that no errors are raised when input is expected.""" + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + labels = { + "classes": constant_op.constant([0] * 100), + } + + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + export.SupervisedInputReceiver(features, labels, receiver_tensors) + + def test_input_receiver_raw_values(self): + """Tests that no errors are raised when input is expected.""" + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + + labels = { + "classes": constant_op.constant([0] * 100), + } + + receiver_tensors = { + 
"example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + rec = export.SupervisedInputReceiver( + features["feature2"], labels, receiver_tensors) + self.assertIsInstance(rec.features, sparse_tensor.SparseTensor) + + rec = export.SupervisedInputReceiver( + features, labels["classes"], receiver_tensors) + self.assertIsInstance(rec.labels, ops.Tensor) + + def test_input_receiver_features_invalid(self): + features = constant_op.constant([0] * 100) + labels = constant_op.constant([0]) + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + + with self.assertRaisesRegexp(ValueError, "features must be defined"): + export.SupervisedInputReceiver( + features=None, + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp(ValueError, "feature keys must be strings"): + export.SupervisedInputReceiver( + features={1: constant_op.constant([1])}, + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp(ValueError, "label keys must be strings"): + export.SupervisedInputReceiver( + features=features, + labels={1: constant_op.constant([1])}, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "feature feature1 must be a Tensor or SparseTensor"): + export.SupervisedInputReceiver( + features={"feature1": [1]}, + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "feature must be a Tensor or SparseTensor"): + export.SupervisedInputReceiver( + features=[1], + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "label must be a Tensor or SparseTensor"): + export.SupervisedInputReceiver( + features=features, + labels=100, + receiver_tensors=receiver_tensors) + + def test_input_receiver_receiver_tensors_invalid(self): + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + labels = constant_op.constant([0]) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors must be defined"): + export.SupervisedInputReceiver( + features=features, + labels=labels, + receiver_tensors=None) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors keys must be strings"): + export.SupervisedInputReceiver( + features=features, + labels=labels, + receiver_tensors={ + 1: array_ops.placeholder(dtypes.string, name="example0")}) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensor example1 must be a Tensor"): + export.SupervisedInputReceiver( + features=features, + labels=labels, + receiver_tensors={"example1": [1]}) + + def test_single_feature_single_receiver(self): + feature = constant_op.constant(5) + label = constant_op.constant(5) + receiver_tensor = array_ops.placeholder(dtypes.string) + input_receiver = export.SupervisedInputReceiver( + feature, label, receiver_tensor) + + # single receiver is automatically named + receiver_key, = input_receiver.receiver_tensors.keys() + self.assertEqual("input", receiver_key) + + def test_multi_feature_single_receiver(self): + features = {"foo": constant_op.constant(5), + "bar": constant_op.constant(6)} + labels = {"value": constant_op.constant(5)} + receiver_tensor = array_ops.placeholder(dtypes.string) + _ = 
export.SupervisedInputReceiver(features, labels, receiver_tensor) + + def test_multi_feature_multi_receiver(self): + features = {"foo": constant_op.constant(5), + "bar": constant_op.constant(6)} + labels = {"value": constant_op.constant(5)} + receiver_tensors = {"baz": array_ops.placeholder(dtypes.int64), + "qux": array_ops.placeholder(dtypes.float32)} + _ = export.SupervisedInputReceiver(features, labels, receiver_tensors) + + def test_feature_labeled_tensor(self): + feature = LabeledTensorMock() + label = constant_op.constant(5) + receiver_tensor = array_ops.placeholder(dtypes.string) + _ = export.SupervisedInputReceiver(feature, label, receiver_tensor) + + +class ExportTest(test_util.TensorFlowTestCase): + def test_build_parsing_serving_input_receiver_fn(self): feature_spec = {"int_feature": parsing_ops.VarLenFeature(dtypes.int64), "float_feature": parsing_ops.VarLenFeature(dtypes.float32)} @@ -237,6 +396,69 @@ class ExportTest(test_util.TensorFlowTestCase): dtypes.int32, serving_input_receiver.receiver_tensors["feature_2"].dtype) + def test_build_raw_supervised_input_receiver_fn(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"foo": constant_op.constant([5]), + "bar": constant_op.constant([6])} + input_receiver_fn = export.build_raw_supervised_input_receiver_fn( + features, labels) + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual(set(["feature_1", "feature_2"]), + set(input_receiver.features.keys())) + self.assertEqual(set(["foo", "bar"]), + set(input_receiver.labels.keys())) + self.assertEqual(set(["feature_1", "feature_2", "foo", "bar"]), + set(input_receiver.receiver_tensors.keys())) + self.assertEqual( + dtypes.string, input_receiver.receiver_tensors["feature_1"].dtype) + self.assertEqual( + dtypes.int32, input_receiver.receiver_tensors["feature_2"].dtype) + + def test_build_raw_supervised_input_receiver_fn_raw_tensors(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"foo": constant_op.constant([5]), + "bar": constant_op.constant([6])} + input_receiver_fn1 = export.build_raw_supervised_input_receiver_fn( + features["feature_1"], labels) + input_receiver_fn2 = export.build_raw_supervised_input_receiver_fn( + features["feature_1"], labels["foo"]) + with ops.Graph().as_default(): + input_receiver = input_receiver_fn1() + self.assertIsInstance(input_receiver.features, ops.Tensor) + self.assertEqual(set(["foo", "bar"]), + set(input_receiver.labels.keys())) + self.assertEqual(set(["input", "foo", "bar"]), + set(input_receiver.receiver_tensors.keys())) + + input_receiver = input_receiver_fn2() + self.assertIsInstance(input_receiver.features, ops.Tensor) + self.assertIsInstance(input_receiver.labels, ops.Tensor) + self.assertEqual(set(["input", "label"]), + set(input_receiver.receiver_tensors.keys())) + + def test_build_raw_supervised_input_receiver_fn_batch_size(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"foo": constant_op.constant([5]), + "bar": constant_op.constant([6])} + input_receiver_fn = export.build_raw_supervised_input_receiver_fn( + features, labels, default_batch_size=10) + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual([10], input_receiver.receiver_tensors["feature_1"].shape) + self.assertEqual([10], input_receiver.features["feature_1"].shape) + + def 
test_build_raw_supervised_input_receiver_fn_overlapping_keys(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"feature_1": constant_op.constant([5]), + "bar": constant_op.constant([6])} + with self.assertRaises(ValueError): + export.build_raw_supervised_input_receiver_fn(features, labels) + def test_build_all_signature_defs_without_receiver_alternatives(self): receiver_tensor = array_ops.placeholder(dtypes.string) output_1 = constant_op.constant([1.]) @@ -404,6 +626,35 @@ class ExportTest(test_util.TensorFlowTestCase): self.assertTrue(int(time_1) < int(time_2)) self.assertTrue(int(time_2) < int(time_3)) + def test_build_all_signature_defs_serving_only(self): + receiver_tensor = {"input": array_ops.placeholder(dtypes.string)} + output_1 = constant_op.constant([1.]) + export_outputs = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_output.PredictOutput(outputs=output_1), + "train": export_output.TrainOutput(loss=output_1), + } + + signature_defs = export.build_all_signature_defs( + receiver_tensor, export_outputs) + + expected_signature_defs = { + "serving_default": signature_def_utils.predict_signature_def( + receiver_tensor, {"output": output_1}) + } + + self.assertDictEqual(expected_signature_defs, signature_defs) + + signature_defs = export.build_all_signature_defs( + receiver_tensor, export_outputs, serving_only=False) + + expected_signature_defs.update({ + "train": signature_def_utils.supervised_train_signature_def( + receiver_tensor, loss={"loss": output_1}) + }) + + self.assertDictEqual(expected_signature_defs, signature_defs) + class TensorServingReceiverTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 8111ab564c..3edf9fe940 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook from tensorflow.python.util import nest @@ -53,6 +54,13 @@ class ModeKeys(object): LOSS_METRIC_KEY = 'loss' AVERAGE_LOSS_METRIC_KEY = 'average_loss' +# Mapping of the modes to appropriate tag_constants that are used for saving. +EXPORT_TAG_MAP = { + ModeKeys.PREDICT: [tag_constants.SERVING], + ModeKeys.TRAIN: [tag_constants.TRAINING], + ModeKeys.EVAL: [tag_constants.EVAL], +} + @tf_export('estimator.EstimatorSpec') class EstimatorSpec( @@ -326,6 +334,57 @@ class EstimatorSpec( return EstimatorSpec(*new_fields) +class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [ + 'mode', + 'predictions', + 'loss', + 'train_op', + 'eval_metrics', + 'export_outputs', + 'scaffold_fn', + 'host_call'])): + """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. + + This is a simplified implementation of `tf.contrib.tpu.EstimatorSpec`. See + tensorflow/contrib/tpu/python/tpu/tpu_estimator.py for more detailed + documentation. 
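The `EXPORT_TAG_MAP` added to model_fn.py above keys each estimator mode to the SavedModel tags that mark the corresponding MetaGraph when it is saved. A minimal, self-contained sketch of the lookup pattern; the `ModeKeys`/`tag_constants` classes and their string values here are local stand-ins, not the TensorFlow definitions:

    # Stand-ins for tf.estimator.ModeKeys and tf.saved_model.tag_constants;
    # the string values are illustrative assumptions only.
    class ModeKeys(object):
      TRAIN = 'train'
      EVAL = 'eval'
      PREDICT = 'infer'

    class tag_constants(object):
      SERVING = 'serve'
      TRAINING = 'train'
      EVAL = 'eval'

    # Mapping of the modes to the tag_constants used for saving, mirroring
    # the EXPORT_TAG_MAP introduced above.
    EXPORT_TAG_MAP = {
        ModeKeys.PREDICT: [tag_constants.SERVING],
        ModeKeys.TRAIN: [tag_constants.TRAINING],
        ModeKeys.EVAL: [tag_constants.EVAL],
    }

    def tags_for_mode(mode):
      """Returns the SavedModel tags to attach to the graph saved for `mode`."""
      return EXPORT_TAG_MAP[mode]

    assert tags_for_mode(ModeKeys.TRAIN) == ['train']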
+ """ + + def __new__(cls, + mode, + predictions=None, + loss=None, + train_op=None, + eval_metrics=None, + export_outputs=None, + scaffold_fn=None, + host_call=None): + """Creates a `_TPUEstimatorSpec` instance.""" + return super(_TPUEstimatorSpec, cls).__new__(cls, + mode=mode, + predictions=predictions, + loss=loss, + train_op=train_op, + eval_metrics=eval_metrics, + export_outputs=export_outputs, + scaffold_fn=scaffold_fn, + host_call=host_call) + + def as_estimator_spec(self): + """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" + if not self.eval_metrics: + eval_metric_ops = None + else: + metric_fn, tensors = self.eval_metrics + eval_metric_ops = metric_fn(**tensors) + return EstimatorSpec(mode=self.mode, + predictions=self.predictions, + loss=self.loss, + train_op=self.train_op, + eval_metric_ops=eval_metric_ops, + export_outputs=self.export_outputs) + + def _check_is_tensor_or_operation(x, name): if not (isinstance(x, ops.Operation) or isinstance(x, ops.Tensor)): raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x)) diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index e7f9e590af..f82e94b1a3 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -696,7 +696,7 @@ class _FuncGraph(ops.Graph): return super(_FuncGraph, self).create_op(op_type, inputs, data_types, **kwargs) - def capture(self, tensor): + def capture(self, tensor, name=None): """Adds the given tensor to this graph and returns the captured tensor.""" if tensor in self._captured: # Captured already. @@ -704,15 +704,16 @@ class _FuncGraph(ops.Graph): elif self._capture_by_value: return self._add_tensor_and_parents(tensor) else: - return self._capture_tensor_as_extra_input(tensor) + return self._capture_tensor_as_extra_input(tensor, name) - def _capture_tensor_as_extra_input(self, tensor): + def _capture_tensor_as_extra_input(self, tensor, name=None): # Substitute with a placeholder. self.extra_inputs.append(tensor) # Hoist the new input placeholder out of any control flow context # we're currently in. with ops.control_dependencies(None): - ph = array_ops.placeholder(tensor.dtype, shape=tensor.get_shape()) + ph = array_ops.placeholder( + tensor.dtype, shape=tensor.get_shape(), name=name) # pylint: disable=protected-access if ops._USE_C_SHAPES: handle_data = c_api.GetResourceHandleShapeAndType(tensor.graph._c_graph, diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 2209e8e21a..de3bf0032b 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1057,13 +1057,19 @@ def internal_convert_to_tensor(value, """ if ctx is None: ctx = context.context() - if ctx.executing_eagerly(): - # Fast path for EagerTensors that don't need any conversion. - if isinstance(value, EagerTensor): + if isinstance(value, EagerTensor): + if ctx.executing_eagerly(): + # Fast path for EagerTensors that don't need any conversion. # Note that we don't check that value's dtype matches the dtype # argument. We expect that the C runtime will do that checking # when we execute the kernel. 
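`as_estimator_spec` above assumes `eval_metrics` is a `(metric_fn, tensors)` pair whose second element is a dict of keyword arguments for `metric_fn`. A small runnable sketch of that unpacking convention; `my_metric_fn` is hypothetical and returns plain numbers where real code would return TF metric ops:

    def as_eval_metric_ops(eval_metrics):
      """Expands a (metric_fn, tensors) pair the way as_estimator_spec does."""
      if not eval_metrics:
        return None
      metric_fn, tensors = eval_metrics
      # `tensors` maps metric_fn's argument names to the values to evaluate.
      return metric_fn(**tensors)

    # Hypothetical metric_fn; in real code this would build TF metric ops.
    def my_metric_fn(labels, predictions):
      matches = sum(1 for l, p in zip(labels, predictions) if l == p)
      return {'accuracy': matches / float(len(labels))}

    metric_ops = as_eval_metric_ops(
        (my_metric_fn, {'labels': [0, 1, 1], 'predictions': [0, 1, 0]}))
    assert abs(metric_ops['accuracy'] - 2.0 / 3) < 1e-6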
return value + else: + graph = get_default_graph() + if not graph.building_function: + raise RuntimeError("Attempting to capture an EagerTensor without " + "building a function.") + return graph.capture(value, name=name) if dtype is not None: dtype = dtypes.as_dtype(dtype) @@ -1251,7 +1257,10 @@ def internal_convert_to_tensor_or_indexed_slices(value, Raises: ValueError: If `dtype` does not match the element type of `value`. """ - if isinstance(value, _TensorLike): + if isinstance(value, EagerTensor) and not context.executing_eagerly(): + return internal_convert_to_tensor( + value, dtype=dtype, name=name, as_ref=as_ref) + elif isinstance(value, _TensorLike): if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype): raise ValueError( "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" % diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 1b66f58939..523eb67935 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -395,7 +395,7 @@ py_test( py_test( name = "resnet50_test", - size = "small", + size = "medium", srcs = ["_impl/keras/applications/resnet50_test.py"], srcs_version = "PY2AND3", deps = [ @@ -563,7 +563,7 @@ py_test( py_test( name = "normalization_test", - size = "small", + size = "medium", srcs = ["_impl/keras/layers/normalization_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], @@ -604,6 +604,7 @@ py_test( name = "lstm_test", size = "medium", srcs = ["_impl/keras/layers/lstm_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = [ "noasan", # times out b/63678675 diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py index 3af4eaabe9..16ee2952b2 100644 --- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py +++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py @@ -1658,7 +1658,7 @@ class DeferredTensor(object): """Tensor-like object used to build graphs of layers in Eager mode. When calling a layer on a DeferredTensor, the layer will not perform any - computation and will simply perfom shape inference to return new + computation and will simply perform shape inference to return new DeferredTensors with appropriate shape information. Thus DeferredTensor behaves like a graph-mode Tensor when manipulated by layers. """ diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py index 3197d49fce..b7fab6e974 100644 --- a/tensorflow/python/keras/_impl/keras/engine/network.py +++ b/tensorflow/python/keras/_impl/keras/engine/network.py @@ -318,6 +318,9 @@ class Network(base_layer.Layer): layer, name='layer-%d' % layer_index, overwrite=True) def __setattr__(self, name, value): + no_dependency = isinstance(value, checkpointable.NoDependency) + if no_dependency: + value = value.value if isinstance(value, (base_layer.Layer, Network)): try: is_graph_network = self._is_graph_network @@ -332,7 +335,8 @@ class Network(base_layer.Layer): # In subclassed models, legacy layers (tf.layers) must always use # resource variables. 
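The `internal_convert_to_tensor` change above means an `EagerTensor` encountered while a function graph is being built is captured as a (possibly named) placeholder input of that graph rather than rejected. A simplified sketch of the dispatch, with stand-in classes in place of the TensorFlow types:

    class EagerTensor(object):
      """Stand-in for an eagerly computed tensor."""
      def __init__(self, value):
        self.value = value

    class FuncGraph(object):
      """Stand-in for a graph under construction (building_function=True)."""
      building_function = True

      def __init__(self):
        self.captures = []  # (external value, internal placeholder name)

      def capture(self, tensor, name=None):
        # Mirror _FuncGraph.capture: record the external tensor and hand
        # back a placeholder-like internal reference.
        placeholder = name or 'capture_%d' % len(self.captures)
        self.captures.append((tensor, placeholder))
        return placeholder

    def convert(value, graph, executing_eagerly):
      if isinstance(value, EagerTensor):
        if executing_eagerly:
          return value  # Fast path: no conversion needed.
        if not graph.building_function:
          raise RuntimeError('Attempting to capture an EagerTensor without '
                             'building a function.')
        return graph.capture(value)
      return value

    g = FuncGraph()
    assert convert(EagerTensor(3), g, executing_eagerly=False) == 'capture_0'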
value._use_resource_variables = True - if isinstance(value, checkpointable.CheckpointableBase): + if (not no_dependency + and isinstance(value, checkpointable.CheckpointableBase)): # Layer (and therefore Network/Model) inherit from CheckpointableBase # rather than Checkpointable, which means there is no Checkpointable # __setattr__ override (it would be a performance issue for functional diff --git a/tensorflow/python/keras/_impl/keras/engine/sequential_test.py b/tensorflow/python/keras/_impl/keras/engine/sequential_test.py index 8aba16aef3..a90ad131a5 100644 --- a/tensorflow/python/keras/_impl/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/sequential_test.py @@ -20,8 +20,11 @@ from __future__ import print_function import numpy as np +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test from tensorflow.python.training import rmsprop @@ -75,7 +78,7 @@ class TestSequential(test.TestCase): model.pop() @tf_test_util.run_in_graph_and_eager_modes() - def test_sequential_deferred_build(self): + def test_sequential_deferred_build_with_np_arrays(self): num_hidden = 5 input_dim = 3 batch_size = 5 @@ -100,6 +103,40 @@ class TestSequential(test.TestCase): self.assertEqual(len(model.weights), 2 * 2) @tf_test_util.run_in_graph_and_eager_modes() + def test_sequential_deferred_build_with_dataset_iterators(self): + if not context.executing_eagerly(): + # TODO(psv/fchollet): Add support for this use case in graph mode. + return + num_hidden = 5 + input_dim = 3 + num_classes = 2 + num_samples = 50 + steps_per_epoch = 10 + + model = keras.models.Sequential() + # We don't specify the input shape. 
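The `Network.__setattr__` change above unwraps a `checkpointable.NoDependency` marker so a value can be stored on the object without registering a checkpoint dependency. A stripped-down sketch of the unwrap-and-skip-tracking logic, using plain-Python stand-ins:

    class NoDependency(object):
      """Stand-in wrapper: 'store this value, but do not track it'."""
      def __init__(self, value):
        self.value = value

    class Trackable(object):
      """Stand-in for checkpointable.CheckpointableBase."""

    class Net(object):
      def __init__(self):
        object.__setattr__(self, '_tracked', [])

      def __setattr__(self, name, value):
        no_dependency = isinstance(value, NoDependency)
        if no_dependency:
          value = value.value  # Unwrap; the raw value is what gets stored.
        if not no_dependency and isinstance(value, Trackable):
          self._tracked.append(name)  # Tracked only when not opted out.
        object.__setattr__(self, name, value)

    net = Net()
    net.layer = Trackable()                # tracked as a dependency
    net.cache = NoDependency(Trackable())  # stored, but not tracked
    assert net._tracked == ['layer']
    assert isinstance(net.cache, Trackable)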
+ model.add(keras.layers.Dense(num_hidden)) + model.add(keras.layers.Dense(num_classes)) + model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) + self.assertEqual(len(model.layers), 2) + self.assertEqual(len(model.weights), 0) + self.assertFalse(model.built) + + x = array_ops.ones((num_samples, input_dim)) + y = array_ops.zeros((num_samples, num_classes)) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + iterator = dataset.make_one_shot_iterator() + + model.fit(iterator, epochs=1, steps_per_epoch=steps_per_epoch) + self.assertTrue(model.built) + self.assertEqual(model.inputs[0].get_shape().as_list(), [None, input_dim]) + self.assertEqual(model.outputs[0].get_shape().as_list(), + [None, num_classes]) + self.assertEqual(len(model.weights), 2 * 2) + + @tf_test_util.run_in_graph_and_eager_modes() def test_invalid_use_cases(self): # Added objects must be layer instances with self.assertRaises(TypeError): diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py index 5f9b3e8c7d..c7623d2b52 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training.py +++ b/tensorflow/python/keras/_impl/keras/engine/training.py @@ -18,11 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import weakref import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.keras._impl.keras import backend as K @@ -106,6 +108,11 @@ class Model(Network): ``` """ + def __init__(self, *args, **kwargs): + super(Model, self).__init__(*args, **kwargs) + # Create a cache for iterator get_next op. + self._iterator_get_next = weakref.WeakKeyDictionary() + def compile(self, optimizer, loss=None, @@ -623,12 +630,23 @@ class Model(Network): **kwargs) self._post_build_cleanup() + def _get_iterator_get_next_tensors(self, iterator): + get_next_op = self._iterator_get_next.get(iterator, None) + if get_next_op is None: + get_next_op = iterator.get_next() + self._iterator_get_next[iterator] = get_next_op + return get_next_op + def _standardize_user_data(self, x, y=None, sample_weight=None, class_weight=None, - batch_size=None): + batch_size=None, + check_steps=False, + steps_name='steps', + steps=None, + validation_split=0): """Runs validation checks on input and target data passed by the user. Also standardizes the data to lists of arrays, in order. @@ -660,6 +678,16 @@ class Model(Network): to, as conveyed by `y`. batch_size: Integer batch size. If provided, it is used to run additional validation checks on stateful models. + check_steps: boolean, True if we want to check for validity of `steps` and + False, otherwise. For example, when we are standardizing one batch of + data for train_on_batch/predict_on_batch/test_on_batch APIs, `steps` + value is not required and we should not check for its validity in these + cases. + steps_name: The public API's parameter name for `steps`. + steps: Integer or `None`. Total number of steps (batches of samples) to + execute. + validation_split: Float between 0 and 1. + Fraction of the training data to be used as validation data. Returns: A tuple of 3 lists: input arrays, target arrays, sample-weight arrays. 
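`_get_iterator_get_next_tensors` above memoizes `iterator.get_next()` per iterator, so repeated `fit`/`evaluate` calls in graph mode reuse one get-next op instead of growing the graph; the `WeakKeyDictionary` lets an entry disappear when its iterator is garbage-collected. The caching pattern in isolation:

    import weakref

    class FakeIterator(object):
      """Stand-in iterator whose get_next() should not be called repeatedly."""
      def __init__(self):
        self.get_next_calls = 0

      def get_next(self):
        self.get_next_calls += 1
        return 'next_op_%d' % self.get_next_calls

    class Trainer(object):
      def __init__(self):
        # Keys are held weakly: when an iterator is garbage-collected, its
        # cached get_next op is dropped automatically.
        self._iterator_get_next = weakref.WeakKeyDictionary()

      def _get_iterator_get_next_tensors(self, iterator):
        get_next_op = self._iterator_get_next.get(iterator, None)
        if get_next_op is None:
          get_next_op = iterator.get_next()
          self._iterator_get_next[iterator] = get_next_op
        return get_next_op

    t = Trainer()
    it = FakeIterator()
    assert t._get_iterator_get_next_tensors(it) == 'next_op_1'
    assert t._get_iterator_get_next_tensors(it) == 'next_op_1'  # cached
    assert it.get_next_calls == 1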
@@ -671,33 +699,54 @@ class Model(Network):
       ValueError: In case of invalid user-provided data.
       RuntimeError: If the model was never compiled.
     """
-    # First, we build/compile the model on the fly if necessary.
     if isinstance(x, dataset_ops.Dataset):
       raise ValueError('You passed a `Dataset` instance to your model (%s), '
                        'which is not supported. Instead, pass an `Iterator`, '
                        'which you can obtain e.g. via '
                        '`dataset.make_one_shot_iterator()` (the exact method '
                        'to use will depend on your specific dataset).' % x)
-    if isinstance(x, iterator_ops.Iterator):
-      if y is not None:
-        raise ValueError('You passed a dataset iterator (%s) as input `x` to '
-                         'your model. In that case, you should not specify '
-                         'a target (`y`) argument, since the dataset iterator '
-                         'generates both input data and target data. '
-                         'Received: %s' % (x, y))
-      if not context.executing_eagerly():
-        x, y = x.get_next()
-        # TODO(fchollet): handle case of `get_next` not returning 2 tensors?
-      else:
-        # TODO(psv): implement this. The way to support it will be to typecheck
-        # for `iterator` before `_standardize_user_data` is called and redirect
-        # to new training/eval functions in `training_eager.py`. The model
-        # may need to get built using the specs of the data from the first batch
-        # drawn from the iterator.
-        raise ValueError('Dataset iterators are not supported '
-                         'with eager execution yet.')
+    # Validate the `steps` argument based on x's type.
+    if check_steps:
+      training_utils.check_steps_argument(x, steps, steps_name)
+
+    is_x_eager_iterator = isinstance(x, iterator_ops.EagerIterator)
+    is_x_iterator = isinstance(x, iterator_ops.Iterator)
+
+    # Validate user inputs when data is given as a dataset iterator.
+    if is_x_iterator or is_x_eager_iterator:
+      training_utils.validate_iterator_input(x, y, sample_weight,
+                                             validation_split)
+
+    # For eager iterators, when we have to process multiple batches of samples,
+    # we will standardize the data when we actually loop over the iterator and
+    # get the batches. For now, we just return the iterator as is.
+    if is_x_eager_iterator and steps is not None:
+      return x, y, sample_weight
+
+    # If input data is a dataset iterator in graph mode or if it is an eager
+    # iterator and only one batch of samples is required, we fetch the data
+    # tensors from the iterator and then standardize them.
+    if is_x_iterator or is_x_eager_iterator:
+      try:
+        if is_x_iterator:
+          next_element = self._get_iterator_get_next_tensors(x)
+        else:
+          next_element = x.get_next()
+      except errors.OutOfRangeError:
+        raise RuntimeError('Your dataset iterator ran out of data; '
+                           'make sure that your dataset can generate the '
+                           'required number of samples.')
+
+      if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
+        raise ValueError('Please provide data as a list or tuple of 2 '
+                         'elements - input and target pair. '
+                         'Received %s' % next_element)
+      x, y = next_element
+
+    # First, we build/compile the model on the fly if necessary.
     all_inputs = []
+    is_build_called = False
+    is_compile_called = False
     if not self.built:
       # We need to use `x` to set the model inputs.
       # We type-check that `x` and `y` are either single arrays
@@ -720,6 +769,7 @@ class Model(Network):
       # If values, then in symbolic-mode placeholders will be created
       # to match the value shapes.
       if not self.inputs:
+        is_build_called = True
        self._set_inputs(x)
 
     if y is not None:
@@ -736,6 +786,7 @@ class Model(Network):
           raise ValueError('Please provide as model targets either a single '
                            'array or a list of arrays. 
' 'You passed: y=' + str(y)) + all_inputs += list(y) elif isinstance(y, dict): raise ValueError('Please do not pass a dictionary as model targets.') else: @@ -743,14 +794,10 @@ class Model(Network): raise ValueError('Please provide as model targets either a single ' 'array or a list of arrays. ' 'You passed: y=' + str(y)) + all_inputs.append(y) # Typecheck that all inputs are *either* value *or* symbolic. # TODO(fchollet): this check could be removed in Eager mode? - if y is not None: - if isinstance(y, (list, tuple)): - all_inputs += list(y) - else: - all_inputs.append(y) if any(tensor_util.is_tensor(v) for v in all_inputs): if not all(tensor_util.is_tensor(v) for v in all_inputs): raise ValueError('Do not pass inputs that mix Numpy arrays and ' @@ -764,17 +811,22 @@ class Model(Network): if not isinstance(y, (list, tuple)): y = [y] target_tensors = [v for v in y if tensor_util.is_tensor(v)] + is_compile_called = True self.compile(optimizer=self.optimizer, loss=self.loss, metrics=self.metrics, loss_weights=self.loss_weights, target_tensors=target_tensors) - # If `x` and `y` were all symbolic, then no model should not be fed any - # inputs and targets. + # In graph mode, if we had just set inputs and targets as symbolic tensors + # by invoking build and compile on the model respectively, we do not have to + # feed anything to the model. Model already has input and target data as + # part of the graph. # Note: in this case, `any` and `all` are equivalent since we disallow # mixed symbolic/value inputs. - if any(tensor_util.is_tensor(v) for v in all_inputs): + if (not context.executing_eagerly() and is_build_called and + is_compile_called and + any(tensor_util.is_tensor(v) for v in all_inputs)): return [], [], [] # What follows is input validation and standardization to list format, @@ -904,7 +956,12 @@ class Model(Network): if isinstance(inputs, list): assert len(inputs) == 1 inputs = inputs[0] - self.build(input_shape=(None,) + inputs.shape[1:]) + + if tensor_util.is_tensor(inputs): + input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:]) + else: + input_shape = (None,) + inputs.shape[1:] + self.build(input_shape=input_shape) elif context.executing_eagerly(): self._eager_set_inputs(inputs) else: @@ -931,12 +988,18 @@ class Model(Network): # On-the-fly setting of model inputs/outputs as DeferredTensors, # to keep track of number of inputs and outputs and their ndim. if isinstance(inputs, (list, tuple)): - dummy_output_values = self.call( - [ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs]) + if tensor_util.is_tensor(inputs[0]): + dummy_output_values = self.call(inputs) + else: + dummy_output_values = self.call( + [ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs]) dummy_input_values = list(inputs) else: - dummy_output_values = self.call( - ops.convert_to_tensor(inputs, dtype=K.floatx())) + if tensor_util.is_tensor(inputs): + dummy_output_values = self.call(inputs) + else: + dummy_output_values = self.call( + ops.convert_to_tensor(inputs, dtype=K.floatx())) dummy_input_values = [inputs] if isinstance(dummy_output_values, (list, tuple)): dummy_output_values = list(dummy_output_values) @@ -1071,7 +1134,7 @@ class Model(Network): batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. - Do not specify the `batch_size` is your data is in the + Do not specify the `batch_size` if your data is in the form of symbolic tensors or dataset iterators (since they generate batches). epochs: Integer. 
Number of epochs to train the model. @@ -1094,7 +1157,8 @@ class Model(Network): the loss and any model metrics on this data at the end of each epoch. The validation data is selected from the last samples - in the `x` and `y` data provided, before shuffling. + in the `x` and `y` data provided, before shuffling. This argument is + not supported when `x` is a dataset iterator. validation_data: Data on which to evaluate the loss and any model metrics at the end of each epoch. The model will not be trained on this data. @@ -1124,7 +1188,8 @@ class Model(Network): `(samples, sequence_length)`, to apply a different weight to every timestep of every sample. In this case you should make sure to specify - `sample_weight_mode="temporal"` in `compile()`. + `sample_weight_mode="temporal"` in `compile()`. This argument is not + supported when `x` is a dataset iterator. initial_epoch: Integer. Epoch at which to start training (useful for resuming a previous training run). @@ -1165,21 +1230,23 @@ class Model(Network): epochs = kwargs.pop('nb_epoch') if kwargs: raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) - if x is None and y is None and steps_per_epoch is None: - raise ValueError('If fitting from data tensors, ' - 'you should specify the `steps_per_epoch` ' - 'argument.') - # Validate user data. + # Validate and standardize user data. x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight, class_weight=class_weight, - batch_size=batch_size) + batch_size=batch_size, + check_steps=True, + steps_name='steps_per_epoch', + steps=steps_per_epoch, + validation_split=validation_split) + # Prepare validation data. if validation_data: - if isinstance(validation_data, iterator_ops.Iterator): + if (isinstance(validation_data, iterator_ops.Iterator) or + isinstance(validation_data, iterator_ops.EagerIterator)): val_x = validation_data val_y = None val_sample_weight = None @@ -1196,11 +1263,13 @@ class Model(Network): 'or alternatively it could be a dataset iterator. However we ' 'received `validation_data=%s`' % validation_data) + # Validate and standardize validation data. val_x, val_y, val_sample_weights = self._standardize_user_data( val_x, val_y, sample_weight=val_sample_weight, - batch_size=batch_size) + batch_size=batch_size, + steps=validation_steps) elif validation_split and 0. < validation_split < 1.: if training_utils.has_symbolic_tensors(x): @@ -1229,6 +1298,7 @@ class Model(Network): inputs=x, targets=y, sample_weights=sample_weights, + class_weight=class_weight, batch_size=batch_size, epochs=epochs, verbose=verbose, @@ -1300,7 +1370,8 @@ class Model(Network): `(samples, sequence_length)`, to apply a different weight to every timestep of every sample. In this case you should make sure to specify - `sample_weight_mode="temporal"` in `compile()`. + `sample_weight_mode="temporal"` in `compile()`. This argument is not + supported when `x` is a dataset iterator. steps: Integer or `None`. Total number of steps (batches of samples) before declaring the evaluation round finished. @@ -1318,17 +1389,16 @@ class Model(Network): # Backwards compatibility. if batch_size is None and steps is None: batch_size = 32 - if x is None and y is None and steps is None: - raise ValueError('If evaluating from data tensors, ' - 'you should specify the `steps` ' - 'argument.') - # Validate user data. + # Validate and standardize user data. 
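Taken together, the `fit` changes above let a dataset iterator be passed directly as `x`, with validation data supplied the same way. Roughly the usage this enables (a sketch, assuming a TensorFlow build containing this change; `make_one_shot_iterator` is the TF 1.x dataset API, and the explicit `input_shape` sidesteps the graph-mode deferred-build gap noted in sequential_test.py):

    import numpy as np
    import tensorflow as tf

    model = tf.keras.models.Sequential(
        [tf.keras.layers.Dense(2, input_shape=(3,))])
    model.compile(loss='mse', optimizer=tf.train.RMSPropOptimizer(1e-3))

    x = np.ones((50, 3), dtype=np.float32)
    y = np.zeros((50, 2), dtype=np.float32)
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).repeat().batch(10)

    # The iterator yields (input, target) pairs, so y, batch_size,
    # sample_weight and validation_split must not be passed; the run is
    # sized with steps_per_epoch and validation_steps instead.
    model.fit(dataset.make_one_shot_iterator(),
              epochs=2,
              steps_per_epoch=5,
              validation_data=dataset.make_one_shot_iterator(),
              validation_steps=2)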
x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight, - batch_size=batch_size) + batch_size=batch_size, + check_steps=True, + steps_name='steps', + steps=steps) if context.executing_eagerly(): return training_eager.test_loop( @@ -1345,7 +1415,12 @@ class Model(Network): Computation is done in batches. Arguments: - x: Input samples, as Numpy array(s) or tensor(s). + x: Input samples. It could be: + - A Numpy array (or array-like), or a list of arrays + (in case the model has multiple inputs). + - A TensorFlow tensor, or a list of tensors + (in case the model has multiple inputs). + - A `tf.data` dataset iterator. batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. @@ -1369,11 +1444,10 @@ class Model(Network): # Backwards compatibility. if batch_size is None and steps is None: batch_size = 32 - if x is None and steps is None: - raise ValueError('If predicting from data tensors, ' - 'you should specify the `steps` ' - 'argument.') - x, _, _ = self._standardize_user_data(x) + + # Validate and standardize user data. + x, _, _ = self._standardize_user_data( + x, check_steps=True, steps_name='steps', steps=steps) if context.executing_eagerly(): return training_eager.predict_loop( @@ -1406,7 +1480,9 @@ class Model(Network): with shape (samples, sequence_length), to apply a different weight to every timestep of every sample. In this case you should make sure to specify - sample_weight_mode="temporal" in compile(). + sample_weight_mode="temporal" in compile(). This argument is not + supported when `x` is a dataset iterator. + class_weight: Optional dictionary mapping class indices (integers) to a weight (float) to apply to the model's loss for the samples @@ -1424,11 +1500,9 @@ class Model(Network): Raises: ValueError: In case of invalid user-provided arguments. """ + # Validate and standardize user data. x, y, sample_weights = self._standardize_user_data( - x, - y, - sample_weight=sample_weight, - class_weight=class_weight) + x, y, sample_weight=sample_weight, class_weight=class_weight) if context.executing_eagerly(): outputs = training_eager.train_on_batch( @@ -1470,7 +1544,8 @@ class Model(Network): with shape (samples, sequence_length), to apply a different weight to every timestep of every sample. In this case you should make sure to specify - sample_weight_mode="temporal" in compile(). + sample_weight_mode="temporal" in compile(). This argument is not + supported when `x` is a dataset iterator. Returns: Scalar test loss (if the model has a single output and no metrics) @@ -1481,6 +1556,7 @@ class Model(Network): Raises: ValueError: In case of invalid user-provided arguments. """ + # Validate and standardize user data. x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight) @@ -1503,23 +1579,34 @@ class Model(Network): """Returns predictions for a single batch of samples. Arguments: - x: Input samples, as Numpy array(s) or tensor(s). + x: Input data. It could be: + - A Numpy array (or array-like), or a list of arrays + (in case the model has multiple inputs). + - A TensorFlow tensor, or a list of tensors + (in case the model has multiple inputs). + - A `tf.data` dataset iterator. Returns: Numpy array(s) of predictions. + Raises: + ValueError: In case of mismatch between given number of inputs and + expectations of the model. """ - x, _, _ = self._standardize_user_data(x) - + # Validate and standardize user data. 
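`evaluate` and `predict` above now pass `check_steps=True` plus a `steps_name`, so that `training_utils.check_steps_argument` can reject a missing `steps` with the right parameter name. That helper's body is not part of this diff; the following is only a plausible reconstruction inferred from its call sites (the duck-typed iterator test and the message wording are assumptions):

    def check_steps_argument(input_data, steps, steps_name):
      """Hypothetical reconstruction; not the actual training_utils code."""
      # Assumption: iterators are detected by duck-typing on get_next().
      is_iterator = hasattr(input_data, 'get_next')
      if (input_data is None or is_iterator) and steps is None:
        raise ValueError('When using iterators or data tensors as input to '
                         'a model, you should specify the `{}` argument.'
                         .format(steps_name))

    check_steps_argument([1, 2, 3], steps=None, steps_name='steps')  # arrays OK
    try:
      check_steps_argument(None, steps=None, steps_name='steps_per_epoch')
    except ValueError as e:
      assert 'steps_per_epoch' in str(e)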
+ inputs, _, _ = self._standardize_user_data(x) if context.executing_eagerly(): - inputs = [ops.convert_to_tensor(val, dtype=K.floatx()) for val in x] + if not isinstance(inputs, iterator_ops.EagerIterator): + inputs = [ + ops.convert_to_tensor(val, dtype=K.floatx()) for val in inputs + ] return self(inputs) # pylint: disable=not-callable if not context.executing_eagerly(): if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + [0] + ins = inputs + [0] else: - ins = x + ins = inputs self._make_predict_function() outputs = self.predict_function(ins) @@ -1631,8 +1718,7 @@ class Model(Network): steps_per_epoch=10000, epochs=10) ``` Raises: - ValueError: In case the generator yields - data in an invalid format. + ValueError: In case the generator yields data in an invalid format. """ if not self.built and not self._is_graph_network: raise NotImplementedError( @@ -1697,8 +1783,7 @@ class Model(Network): ValueError: in case of invalid arguments. Raises: - ValueError: In case the generator yields - data in an invalid format. + ValueError: In case the generator yields data in an invalid format. """ if not self.built and not self._is_graph_network: raise NotImplementedError( @@ -1751,8 +1836,7 @@ class Model(Network): Numpy array(s) of predictions. Raises: - ValueError: In case the generator yields - data in an invalid format. + ValueError: In case the generator yields data in an invalid format. """ if not self.built and not self._is_graph_network: raise NotImplementedError( diff --git a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py b/tensorflow/python/keras/_impl/keras/engine/training_arrays.py index 4164cae864..12e74ef51d 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_arrays.py @@ -108,8 +108,8 @@ def fit_loop(model, do_validation = False if val_inputs: do_validation = True - if verbose and inputs and hasattr(inputs[0], 'shape') and hasattr( - val_inputs[0], 'shape'): + if (steps_per_epoch is None and verbose and inputs and + hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')): print('Train on %d samples, validate on %d samples' % (inputs[0].shape[0], val_inputs[0].shape[0])) if validation_steps: diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py index b9c99b2222..3617eb281a 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py @@ -23,7 +23,9 @@ import copy import numpy as np +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager.backprop import GradientTape +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.keras._impl.keras import backend @@ -177,6 +179,550 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False): return outs, total_loss, loss_metrics +def iterator_fit_loop(model, + inputs, + class_weight, + steps_per_epoch, + callback_model, + out_labels, + epoch_logs, + val_inputs=None, + val_targets=None, + val_sample_weights=None, + epochs=1, + verbose=1, + callbacks=None, + callback_metrics=None, + validation_steps=None, + do_validation=False): + """Fit function for eager execution when input is given as dataset iterator. + + Updates the given epoch logs. + + Arguments: + model: Instance of the `Model`. 
+    inputs: Input dataset iterator.
+    class_weight: Optional class-weight array to weight the importance of
+      samples in `inputs` based on the class they belong to, as conveyed by
+      the targets from the `inputs` iterator.
+    steps_per_epoch: Total number of steps (batches of samples)
+      before declaring one epoch finished and starting the
+      next epoch.
+    callback_model: Instance of `Model` to callback.
+    out_labels: Output labels generated from model metric names.
+    epoch_logs: Dictionary of logs from every epoch.
+    val_inputs: Input data for validation.
+    val_targets: Target data for validation.
+    val_sample_weights: Sample weight data for validation.
+    epochs: Number of times to iterate over the data.
+    verbose: Verbosity mode, 0, 1 or 2.
+    callbacks: List of callbacks to be called during training.
+    callback_metrics: List of strings, the display names of the metrics
+      passed to the callbacks. They should be the concatenation of the
+      display names of the outputs of `f` and the display names of the
+      outputs of `f_val`.
+    validation_steps: Number of steps to run validation for (only if doing
+      validation from data tensors). Ignored with the default value of `None`.
+    do_validation: Boolean value indicating whether we should do validation.
+
+  Raises:
+    ValueError: In case of mismatch between given number of inputs and
+      expectations of the model.
+  """
+  assert isinstance(inputs, iterator_ops.EagerIterator)
+  for step_index in range(steps_per_epoch):
+    batch_logs = {}
+    batch_logs['batch'] = step_index
+    batch_logs['size'] = 1
+    callbacks.on_batch_begin(step_index, batch_logs)
+
+    # Get data from the iterator.
+    try:
+      next_element = inputs.get_next()
+    except errors.OutOfRangeError:
+      logging.warning(
+          'Your dataset iterator ran out of data; '
+          'interrupting training. Make sure that your dataset '
+          'can generate at least `steps_per_epoch * epochs` '
+          'batches (in this case, %d batches).' % (steps_per_epoch * epochs))
+      break
+
+    if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
+      raise ValueError('Please provide data as a list or tuple of 2 elements '
+                       '- input and target pair. Received %s' % next_element)
+    x, y = next_element
+
+    # Validate and standardize data.
+    x, y, sample_weights = model._standardize_user_data(
+        x, y, class_weight=class_weight)
+    if sample_weights:
+      sample_weights = [
+          ops.convert_to_tensor(val, dtype=backend.floatx())
+          if val is not None else None for val in sample_weights
+      ]
+
+    if step_index == 0 and not callback_metrics:
+      out_labels = model.metrics_names
+      if do_validation:
+        callback_metrics = copy.copy(out_labels) + [
+            'val_' + n for n in out_labels
+        ]
+      else:
+        callback_metrics = copy.copy(out_labels)
+      callbacks.set_params({
+          'epochs': epochs,
+          'steps': steps_per_epoch,
+          'verbose': verbose,
+          'do_validation': do_validation,
+          'metrics': callback_metrics or [],
+      })
+
+    # Train model.
+    outs, loss, loss_metrics = _process_single_batch(
+        model, x, y, sample_weights=sample_weights, training=True)
+    if not isinstance(outs, list):
+      outs = [outs]
+
+    # Calculate metrics.
+ for l, o in zip(out_labels, outs): + batch_logs[l] = o + # Required for eager execution + metrics_results = _eager_metrics_fn(model, outs, y) + batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss)) + + for k, v in zip(model.metrics_names, + [backend.mean(loss)] + loss_metrics + metrics_results): + batch_logs[k] = tensor_util.constant_value(v) + callbacks.on_batch_end(step_index, batch_logs) + if callback_model.stop_training: + break + + if step_index == steps_per_epoch - 1: + if do_validation: + val_outs = test_loop( + model, + val_inputs, + val_targets, + sample_weights=val_sample_weights, + steps=validation_steps, + verbose=0) + if not isinstance(val_outs, list): + val_outs = [val_outs] + # Same labels assumed. + for l, o in zip(out_labels, val_outs): + epoch_logs['val_' + l] = o + + +def batch_fit_loop(model, + inputs, + targets, + epoch_logs, + index_array, + out_labels, + callback_model, + batch_size, + sample_weights=None, + val_inputs=None, + val_targets=None, + val_sample_weights=None, + callbacks=None, + shuffle=True, + num_train_samples=None, + do_validation=False): + """Fit function for eager execution when input is given as arrays or tensors. + + Updates the given epoch logs. + + Arguments: + model: Instance of the `Model`. + inputs: List of input arrays. + targets: List of target arrays. + epoch_logs: Dictionary of logs from every epoch. + index_array: Index array generated from number of training samples. + out_labels: Output labels generated from model metric names. + callback_model: Instance of `Model` to callback. + batch_size: Integer batch size or None if unknown. + sample_weights: Optional list of sample weight arrays. + val_inputs: Input data for validation. + val_targets: Target data for validation. + val_sample_weights: Sample weight data for validation. + callbacks: List of callbacks to be called during training. + shuffle: Whether to shuffle the data at the beginning of each epoch. + num_train_samples: Integer number of training samples. + do_validation: Boolean value indicating whether we should do validation. + """ + # TODO(psv): Create a dataset iterator instead of manually creating batches + # here and in batch_test_loop, batch_predict_loop. 
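All three iterator-driven loops in this file share one guard: `get_next()` is wrapped so an exhausted iterator logs a warning and ends the loop early instead of propagating the error. The shape of that pattern as a self-contained sketch, with a plain Python iterator and exception standing in for the TF ones:

    class OutOfRangeError(Exception):
      """Stand-in for tf.errors.OutOfRangeError."""

    def run_steps(get_next, steps):
      """Runs up to `steps` batches, stopping early if the data runs out."""
      processed = []
      for step_index in range(steps):
        try:
          batch = get_next()
        except OutOfRangeError:
          print('Warning: iterator ran out of data; interrupting after '
                '%d of %d steps.' % (step_index, steps))
          break
        processed.append(batch)
      return processed

    data = iter([1, 2, 3])
    def get_next():
      try:
        return next(data)
      except StopIteration:
        raise OutOfRangeError()

    assert run_steps(get_next, steps=5) == [1, 2, 3]  # stopped after 3 batches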
+  if shuffle == 'batch':
+    index_array = model._batch_shuffle(index_array, batch_size)
+  elif shuffle:
+    np.random.shuffle(index_array)
+
+  batches = generic_utils.make_batches(num_train_samples, batch_size)
+
+  for batch_index, (batch_start, batch_end) in enumerate(batches):
+    batch_ids = index_array[batch_start:batch_end]
+    inputs_batch = slice_arrays(inputs, batch_ids, contiguous=not shuffle)
+    targets_batch = slice_arrays(targets, batch_ids, contiguous=not shuffle)
+    if sample_weights:
+      sample_weights_batch = slice_arrays(
+          sample_weights, batch_ids, contiguous=not shuffle)
+    else:
+      sample_weights_batch = None
+    batch_logs = {}
+    batch_logs['batch'] = batch_index
+    batch_logs['size'] = len(batch_ids)
+
+    callbacks.on_batch_begin(batch_index, batch_logs)
+
+    inputs_batch = [
+        ops.convert_to_tensor(val, dtype=backend.floatx())
+        for val in inputs_batch
+    ]
+    targets_batch = [
+        ops.convert_to_tensor(val, dtype=backend.floatx())
+        for val in targets_batch
+    ]
+    if sample_weights:
+      sample_weights_batch = [
+          ops.convert_to_tensor(val, dtype=backend.floatx())
+          if val is not None else None for val in sample_weights_batch
+      ]
+
+    outs, loss, loss_metrics = _process_single_batch(
+        model,
+        inputs_batch,
+        targets_batch,
+        sample_weights=sample_weights_batch,
+        training=True)
+
+    if not isinstance(outs, list):
+      outs = [outs]
+
+    for l, o in zip(out_labels, outs):
+      batch_logs[l] = o
+    # Required for eager execution
+    metrics_results = _eager_metrics_fn(model, outs, targets_batch)
+    batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss))
+
+    for k, v in zip(model.metrics_names,
+                    [backend.mean(loss)] + loss_metrics + metrics_results):
+      batch_logs[k] = tensor_util.constant_value(v)
+    callbacks.on_batch_end(batch_index, batch_logs)
+    if callback_model.stop_training:
+      break
+
+    if batch_index == len(batches) - 1:  # Last batch.
+      if do_validation:
+        val_outs = test_loop(
+            model,
+            val_inputs,
+            val_targets,
+            sample_weights=val_sample_weights,
+            batch_size=batch_size,
+            verbose=0)
+        if not isinstance(val_outs, list):
+          val_outs = [val_outs]
+        # Same labels assumed.
+        for l, o in zip(out_labels, val_outs):
+          epoch_logs['val_' + l] = o
+
+
+def iterator_test_loop(model, inputs, steps, verbose=0):
+  """Test function for eager execution when input is given as dataset iterator.
+
+  Arguments:
+    model: Model instance that is being evaluated in Eager mode.
+    inputs: Input dataset iterator.
+    steps: Total number of steps (batches of samples) before declaring
+      predictions finished.
+    verbose: Verbosity mode.
+
+  Returns:
+    Scalar loss (if the model has a single output and no metrics)
+    or list of scalars (if the model has multiple outputs
+    and/or metrics). The attribute `model.metrics_names` will give you
+    the display labels for the scalar outputs.
+
+  Raises:
+    ValueError: In case of mismatch between given number of inputs and
+      expectations of the model.
+  """
+  assert isinstance(inputs, iterator_ops.EagerIterator)
+  outs = []
+  num_samples = 0
+  if verbose == 1:
+    progbar = generic_utils.Progbar(target=steps)
+  for step_index in range(steps):
+    # Get data from the iterator.
+    try:
+      next_element = inputs.get_next()
+    except errors.OutOfRangeError:
+      logging.warning(
+          'Your dataset iterator ran out of data; interrupting testing. 
' + 'Make sure that your dataset can generate at least `steps` batches ' + '(in this case, %d batches).', steps) + break + + if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: + raise ValueError('Please provide data as a list or tuple of 2 elements ' + ' - input and target pair. Received %s' % next_element) + x, y = next_element + + # Validate and standardize data. + x, y, sample_weights = model._standardize_user_data(x, y) + + # Calculate model output, loss values. + loss_outs, loss, loss_metrics = _model_loss( + model, x, y, sample_weights=sample_weights, training=False) + metrics_results = _eager_metrics_fn(model, loss_outs, y) + batch_outs = [] + for _, v in zip(model.metrics_names, + [backend.mean(loss)] + loss_metrics + metrics_results): + batch_outs.append(tensor_util.constant_value(v)) + + # Get current step size. + if isinstance(x, list): + step_size = x[0].get_shape().as_list()[0] + else: + step_size = x.get_shape().as_list()[0] + + # Accumulate results in output array. + if not isinstance(batch_outs, list): + batch_outs = [batch_outs] + if step_index == 0: + for _ in enumerate(batch_outs): + outs.append(0.) + for i, batch_out in enumerate(batch_outs): + outs[i] += batch_out * step_size + + # Calculate sample size. + num_samples += step_size + if verbose == 1: + progbar.update(step_index + 1) + + for i in range(len(outs)): + outs[i] /= num_samples + if len(outs) == 1: + return outs[0] + return outs + + +def batch_test_loop(model, + inputs, + targets, + batch_size, + sample_weights=None, + verbose=0): + """Test function for eager execution when input is given as arrays or tensors. + + Arguments: + model: Model instance that is being evaluated in Eager mode. + inputs: List of input arrays. + targets: List of target arrays. + batch_size: Integer batch size. + sample_weights: Optional list of sample weight arrays. + verbose: Verbosity mode. + + Returns: + Scalar loss (if the model has a single output and no metrics) + or list of scalars (if the model has multiple outputs + and/or metrics). The attribute `model.metrics_names` will give you + the display labels for the scalar outputs. 
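`iterator_test_loop` accumulates each metric as a size-weighted sum and divides by the total sample count at the end (and `batch_test_loop` below does the same with `len(batch_ids)`), so a short final batch does not skew the average. The arithmetic in isolation:

    def weighted_average_of_batches(batch_outs_and_sizes):
      """Averages per-batch metric values, weighting each by its batch size."""
      totals = None
      num_samples = 0
      for batch_outs, size in batch_outs_and_sizes:
        if totals is None:
          totals = [0.] * len(batch_outs)
        for i, batch_out in enumerate(batch_outs):
          totals[i] += batch_out * size  # weight by batch size
        num_samples += size
      return [t / num_samples for t in totals]

    # Two full batches of 10 and a final batch of 5: the small batch
    # contributes proportionally less to the result.
    outs = weighted_average_of_batches([([1.0], 10), ([1.0], 10), ([2.0], 5)])
    assert abs(outs[0] - 1.2) < 1e-9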
+ """ + outs = [] + feed_data = inputs + targets + if sample_weights: + feed_data += sample_weights + num_samples = training_utils.check_num_samples( + feed_data, batch_size=batch_size) + if verbose == 1: + progbar = generic_utils.Progbar(target=num_samples) + batches = generic_utils.make_batches(num_samples, batch_size) + index_array = np.arange(num_samples) + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + inputs_batch = slice_arrays(inputs, batch_ids) + targets_batch = slice_arrays(targets, batch_ids) + if sample_weights: + sample_weights_batch = slice_arrays(sample_weights, batch_ids) + else: + sample_weights_batch = None + + inputs_batch = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + for val in inputs_batch + ] + targets_batch = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + for val in targets_batch + ] + if sample_weights: + sample_weights_batch = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + if val is not None else None for val in sample_weights_batch + ] + + loss_outs, loss, loss_metrics = _model_loss( + model, + inputs_batch, + targets_batch, + sample_weights=sample_weights_batch, + training=False) + metrics_results = _eager_metrics_fn(model, loss_outs, targets_batch) + batch_outs = [] + for _, v in zip(model.metrics_names, + [backend.mean(loss)] + loss_metrics + metrics_results): + batch_outs.append(tensor_util.constant_value(v)) + + if isinstance(batch_outs, list): + if batch_index == 0: + for _ in enumerate(batch_outs): + outs.append(0.) + for i, batch_out in enumerate(batch_outs): + outs[i] += batch_out * len(batch_ids) + else: + if batch_index == 0: + outs.append(0.) + outs[0] += batch_outs * len(batch_ids) + + if verbose == 1: + progbar.update(batch_end) + + for i in range(len(outs)): + outs[i] /= num_samples + if len(outs) == 1: + return outs[0] + return outs + + +def iterator_predict_loop(model, inputs, steps, verbose=0): + """Predict function for eager execution when input is dataset iterator. + + Arguments: + model: Instance of `Model`. + inputs: Input dataset iterator. + steps: Total number of steps (batches of samples) before declaring + `_predict_loop` finished. + verbose: Verbosity mode. + + Returns: + Array of predictions (if the model has a single output) + or list of arrays of predictions (if the model has multiple outputs). + + Raises: + ValueError: In case of mismatch between given number of inputs and + expectations of the model. + """ + assert isinstance(inputs, iterator_ops.EagerIterator) + outs = [] + if verbose == 1: + progbar = generic_utils.Progbar(target=steps) + for step_index in range(steps): + # Get data from the iterator. + try: + next_element = inputs.get_next() + except errors.OutOfRangeError: + logging.warning( + 'Your dataset iterator ran out of data; ' + 'interrupting prediction. Make sure that your ' + 'dataset can generate at least `steps` ' + 'batches (in this case, %d batches).', steps) + break + + if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: + raise ValueError( + 'Please provide data as a list or tuple of 2 elements ' + ' - input and target pair. Received %s. We do not use the ' + '`target` value here.' % next_element) + x, _ = next_element + + # Validate and standardize data. 
+ x, _, _ = model._standardize_user_data(x) + + if model._expects_training_arg: + batch_outs = model.call(x[0] if len(x) == 1 else x, training=False) + else: + batch_outs = model.call(x[0] if len(x) == 1 else x) + if not isinstance(batch_outs, list): + batch_outs = [batch_outs] + + # We collect the results from every step and then concatenate them once + # in the end. This is an expensive process. We are doing this because we + # do not know the number of samples beforehand. + if step_index == 0: + for _ in batch_outs: + outs.append([]) + for i, batch_out in enumerate(batch_outs): + outs[i].append(backend.get_value(batch_out)) + + if verbose == 1: + progbar.update(step_index + 1) + for i, out in enumerate(outs): + outs[i] = np.concatenate(tuple(out), axis=0) + if len(outs) == 1: + return outs[0] + return outs + + +def batch_predict_loop(model, inputs, batch_size, verbose=0): + """Predict function for eager execution when input is arrays or tensors. + + Arguments: + model: Instance of `Model`. + inputs: List of input arrays. + batch_size: Integer batch size. + verbose: Verbosity mode. + + Returns: + Array of predictions (if the model has a single output) + or list of arrays of predictions (if the model has multiple outputs). + """ + outs = [] + num_samples = training_utils.check_num_samples(inputs, batch_size) + if verbose == 1: + progbar = generic_utils.Progbar(target=num_samples) + batches = generic_utils.make_batches(num_samples, batch_size) + index_array = np.arange(num_samples) + for batch_index, (batch_start, batch_end) in enumerate(batches): + batch_ids = index_array[batch_start:batch_end] + inputs_batch = slice_arrays(inputs, batch_ids) + + inputs_batch = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + for val in inputs_batch + ] + + if len(inputs_batch) == 1: + if model._expects_training_arg: + batch_outs = model.call(inputs_batch[0], training=False) + else: + batch_outs = model.call(inputs_batch[0]) + else: + if model._expects_training_arg: + batch_outs = model.call(inputs_batch, training=False) + else: + batch_outs = model.call(inputs_batch) + + if not isinstance(batch_outs, list): + batch_outs = [batch_outs] + if batch_index == 0: + # Pre-allocate the results arrays. + for batch_out in batch_outs: + dims = batch_out.shape[1:].dims + dims_list = [d.value for d in dims] + shape = (num_samples,) + tuple(dims_list) + outs.append(np.zeros(shape, dtype=batch_out.dtype.as_numpy_dtype)) + for i, batch_out in enumerate(batch_outs): + outs[i][batch_start:batch_end] = batch_out + if verbose == 1: + progbar.update(batch_end) + + if len(outs) == 1: + return outs[0] + return outs + + def slice_arrays(arrays, indices, contiguous=True): """Slices batches out of provided arrays (workaround for eager tensors). @@ -268,19 +814,24 @@ def train_on_batch(model, inputs, targets, sample_weights=None): Returns: total loss and the loss associated with each output. 
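The two predict loops differ in how results are assembled: with an iterator the total sample count is unknown, so per-step outputs are appended and concatenated once at the end, while `batch_predict_loop` preallocates its result arrays because `num_samples` is known up front. The append-then-concatenate half in isolation:

    import numpy as np

    def collect_predictions(batches):
      """Appends per-step outputs, then concatenates once at the end."""
      outs = []
      for step_index, batch_outs in enumerate(batches):
        if step_index == 0:
          for _ in batch_outs:
            outs.append([])  # one accumulator per model output
        for i, batch_out in enumerate(batch_outs):
          outs[i].append(batch_out)
      # Single concatenation; cheaper than growing an array every step.
      return [np.concatenate(tuple(out), axis=0) for out in outs]

    # A model with one output, predicted over three steps of batch size 2.
    step_outputs = [[np.zeros((2, 4))], [np.ones((2, 4))], [np.ones((2, 4))]]
    preds, = collect_predictions(step_outputs)
    assert preds.shape == (6, 4)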
""" - inputs = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs] - targets = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets] - sample_weights = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - if val is not None else None for val in sample_weights] + if len(inputs) and not tensor_util.is_tensor(inputs[0]): + inputs = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs + ] + targets = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets + ] + if sample_weights: + sample_weights = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + if val is not None else None for val in sample_weights + ] + outs, loss, _ = _process_single_batch( model, inputs, targets, sample_weights=sample_weights, training=True) if not isinstance(outs, list): outs = [outs] - metrics_results = _eager_metrics_fn( - model, outs, targets) + metrics_results = _eager_metrics_fn(model, outs, targets) if not isinstance(loss, list): loss = [loss] return loss + metrics_results @@ -298,48 +849,55 @@ def test_on_batch(model, inputs, targets, sample_weights=None): Returns: total loss, loss and metrics associated with each output. """ - inputs = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs] - targets = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets] - sample_weights = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - if val is not None else None for val in sample_weights] - outs, loss, loss_metrics = _process_single_batch( + if len(inputs) and not tensor_util.is_tensor(inputs[0]): + inputs = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs + ] + targets = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets + ] + if sample_weights: + sample_weights = [ + ops.convert_to_tensor(val, dtype=backend.floatx()) + if val is not None else None for val in sample_weights + ] + outs, loss, loss_metrics = _model_loss( model, inputs, targets, sample_weights=sample_weights, training=False) if not isinstance(outs, list): outs = [outs] - metrics_results = _eager_metrics_fn( - model, outs, targets) + metrics_results = _eager_metrics_fn(model, outs, targets) if not isinstance(loss, list): loss = [loss] return loss + loss_metrics + metrics_results -def fit_loop( - model, - inputs, - targets, - sample_weights=None, - val_inputs=None, - val_targets=None, - val_sample_weights=None, - batch_size=None, - epochs=100, - verbose=1, - callbacks=None, - shuffle=True, - callback_metrics=None, - initial_epoch=0, - steps_per_epoch=None, - validation_steps=None): - """Abstract fit function for eager execution. +def fit_loop(model, + inputs, + targets, + sample_weights=None, + class_weight=None, + val_inputs=None, + val_targets=None, + val_sample_weights=None, + batch_size=None, + epochs=1, + verbose=1, + callbacks=None, + shuffle=True, + callback_metrics=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None): + """Fit function for eager execution. Arguments: model: Instance of the model that is being executed in Eager mode. inputs: List of input arrays. targets: List of target arrays. sample_weights: Optional list of sample weight arrays. + class_weight: Optional class-weight array to weight the importance of + samples in `inputs` based on the class they belong to, as conveyed by + `targets`. val_inputs: Input data for validation. val_targets: Target data for validation. val_sample_weights: Sample weight data for validation. 
@@ -366,47 +924,40 @@ def fit_loop(
   Raises:
     ValueError: In case of invalid argument values.
   """
-  if not batch_size:
-    raise ValueError('With eager execution, `batch_size` should be specified.')
-  if steps_per_epoch or validation_steps:
-    raise ValueError('With eager execution, `steps_per_epoch` and '
-                     '`validation_steps` are not valid arguments '
-                     '(set `batch_size` instead).')
-  # Required for Eager mode
+  # Required for eager execution
   with backend.learning_phase_scope(1):
     do_validation = False
     if val_inputs:
       do_validation = True
-      if (verbose and inputs and hasattr(inputs[0], 'shape') and
-          hasattr(val_inputs[0], 'shape')):
+      if (steps_per_epoch is None and verbose and inputs and
+          hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
        print('Train on %d samples, validate on %d samples' %
              (inputs[0].shape[0], val_inputs[0].shape[0]))
-    if validation_steps:
-      if steps_per_epoch is None:
-        raise ValueError('Can only use `validation_steps` when doing step-wise '
-                         'training, i.e. `steps_per_epoch` must be set.')
-      do_validation = True

-    out_labels = model.metrics_names
-    if do_validation:
-      callback_metrics = copy.copy(out_labels) + [
-          'val_' + n for n in out_labels
-      ]
-    else:
-      callback_metrics = copy.copy(out_labels)
+    num_train_samples = None
+    out_labels = None
+    if steps_per_epoch is None or model._is_compiled:
+      out_labels = model.metrics_names
+      if do_validation:
+        callback_metrics = copy.copy(out_labels) + [
+            'val_' + n for n in out_labels
+        ]
+      else:
+        callback_metrics = copy.copy(out_labels)

-    if sample_weights:
-      feed_data = inputs + targets + sample_weights
-    else:
-      feed_data = inputs + targets
-    num_train_samples = training_utils.check_num_samples(
-        feed_data,
-        batch_size=batch_size,
-        steps=steps_per_epoch,
-        steps_name='steps_per_epoch')
+    if steps_per_epoch is None:
+      if sample_weights:
+        feed_data = inputs + targets + sample_weights
+      else:
+        feed_data = inputs + targets
+      num_train_samples = training_utils.check_num_samples(
+          feed_data,
+          batch_size=batch_size,
+          steps=steps_per_epoch,
+          steps_name='steps_per_epoch')

-    if num_train_samples is not None:
-      index_array = np.arange(num_train_samples)
+      if num_train_samples is not None:
+        index_array = np.arange(num_train_samples)

     model.history = cbks.History()
     callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history]
@@ -441,6 +992,8 @@ def fit_loop(
     for cbk in callbacks:
       if not val_inputs:
         cbk.validation_data = []
+      elif isinstance(val_inputs, iterator_ops.EagerIterator):
+        cbk.validation_data = val_inputs
       elif val_sample_weights:
         cbk.validation_data = val_inputs + val_targets + val_sample_weights
       else:
@@ -449,87 +1002,48 @@ def fit_loop(
     for epoch in range(initial_epoch, epochs):
       callbacks.on_epoch_begin(epoch)
       epoch_logs = {}
-      if shuffle == 'batch':
-        index_array = model._batch_shuffle(index_array, batch_size)
-      elif shuffle:
-        np.random.shuffle(index_array)
-
-      batches = generic_utils.make_batches(num_train_samples, batch_size)
-
-      for batch_index, (batch_start, batch_end) in enumerate(batches):
-        batch_ids = index_array[batch_start:batch_end]
-        try:
-          inputs_batch = slice_arrays(inputs, batch_ids,
-                                      contiguous=not shuffle)
-          targets_batch = slice_arrays(targets, batch_ids,
-                                       contiguous=not shuffle)
-          if sample_weights:
-            sample_weights_batch = slice_arrays(sample_weights, batch_ids,
-                                                contiguous=not shuffle)
-          else:
-            sample_weights_batch = None
-        except TypeError:
-          raise TypeError('TypeError while preparing batch. '
-                          'If using HDF5 input data, '
-                          'pass shuffle="batch".')
-        batch_logs = {}
-        batch_logs['batch'] = batch_index
-        batch_logs['size'] = len(batch_ids)
-
-        callbacks.on_batch_begin(batch_index, batch_logs)
-
-        inputs_batch = [
-            ops.convert_to_tensor(val, dtype=backend.floatx())
-            for val in inputs_batch]
-        targets_batch = [
-            ops.convert_to_tensor(val, dtype=backend.floatx())
-            for val in targets_batch]
-        if sample_weights:
-          sample_weights_batch = [
-              ops.convert_to_tensor(val, dtype=backend.floatx())
-              if val is not None else None
-              for val in sample_weights_batch]
-
-        outs, loss, loss_metrics = _process_single_batch(
+
+      if steps_per_epoch is not None:
+        iterator_fit_loop(
             model,
-            inputs_batch,
-            targets_batch,
-            sample_weights=sample_weights_batch,
-            training=True)
-
-        if not isinstance(outs, list):
-          outs = [outs]
-
-        for l, o in zip(out_labels, outs):
-          batch_logs[l] = o
-        # Required for Eager mode
-        metrics_results = _eager_metrics_fn(model, outs, targets_batch)
-        batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss))
-
-        for k, v in zip(model.metrics_names,
-                        [backend.mean(loss)] + loss_metrics + metrics_results):
-          batch_logs[k] = tensor_util.constant_value(v)
-        callbacks.on_batch_end(batch_index, batch_logs)
-        if callback_model.stop_training:
-          break
-
-        if batch_index == len(batches) - 1:  # Last batch.
-          if do_validation:
-            val_outs = test_loop(
-                model, val_inputs, val_targets,
-                sample_weights=val_sample_weights,
-                batch_size=batch_size,
-                verbose=0)
-            if not isinstance(val_outs, list):
-              val_outs = [val_outs]
-            # Same labels assumed.
-            for l, o in zip(out_labels, val_outs):
-              epoch_logs['val_' + l] = o
+            inputs,
+            class_weight,
+            steps_per_epoch=steps_per_epoch,
+            callback_model=callback_model,
+            out_labels=out_labels,
+            epoch_logs=epoch_logs,
+            val_inputs=val_inputs,
+            val_targets=val_targets,
+            val_sample_weights=val_sample_weights,
+            epochs=epochs,
+            verbose=verbose,
+            callbacks=callbacks,
+            callback_metrics=callback_metrics,
+            validation_steps=validation_steps,
+            do_validation=do_validation)
+      else:
+        batch_fit_loop(
+            model,
+            inputs,
+            targets,
+            epoch_logs=epoch_logs,
+            index_array=index_array,
+            out_labels=out_labels,
+            callback_model=callback_model,
+            batch_size=batch_size,
+            sample_weights=sample_weights,
+            val_inputs=val_inputs,
+            val_targets=val_targets,
+            val_sample_weights=val_sample_weights,
+            callbacks=callbacks,
+            shuffle=shuffle,
+            num_train_samples=num_train_samples,
+            do_validation=do_validation)
       callbacks.on_epoch_end(epoch, epoch_logs)
       if callback_model.stop_training:
         break
-  callbacks.on_train_end()
-  return model.history
+    callbacks.on_train_end()
+    return model.history


 def test_loop(model, inputs, targets,
@@ -537,7 +1051,7 @@ def test_loop(model, inputs, targets,
               batch_size=None,
               verbose=0,
               steps=None):
-  """Abstract method to loop over some data in batches.
+  """Test function for eager execution.

   Arguments:
       model: Model instance that is being evaluated in Eager mode.
@@ -557,77 +1071,26 @@ def test_loop(model, inputs, targets,
       the display labels for the scalar outputs.
""" with backend.learning_phase_scope(0): - feed_data = inputs + targets - if sample_weights: - feed_data += sample_weights - num_samples = training_utils.check_num_samples( - feed_data, batch_size=batch_size, steps=steps, steps_name='steps') - outs = [] - if verbose == 1: - progbar = generic_utils.Progbar(target=num_samples) - batches = generic_utils.make_batches(num_samples, batch_size) - index_array = np.arange(num_samples) - for batch_index, (batch_start, batch_end) in enumerate(batches): - batch_ids = index_array[batch_start:batch_end] - inputs_batch = slice_arrays(inputs, batch_ids) - targets_batch = slice_arrays(targets, batch_ids) - if sample_weights: - sample_weights_batch = slice_arrays(sample_weights, batch_ids) - else: - sample_weights_batch = None - - inputs_batch = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - for val in inputs_batch] - targets_batch = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - for val in targets_batch] - if sample_weights: - sample_weights_batch = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - if val is not None else None - for val in sample_weights_batch] - - loss_outs, loss, loss_metrics = _model_loss( + if steps is not None: + return iterator_test_loop(model, inputs, steps, verbose=verbose) + else: + return batch_test_loop( model, - inputs_batch, - targets_batch, - sample_weights=sample_weights_batch, - training=False) - metrics_results = _eager_metrics_fn(model, loss_outs, targets_batch) - batch_outs = [] - for _, v in zip(model.metrics_names, - [backend.mean(loss)] + loss_metrics + metrics_results): - batch_outs.append(tensor_util.constant_value(v)) - - if isinstance(batch_outs, list): - if batch_index == 0: - for batch_out in enumerate(batch_outs): - outs.append(0.) - for i, batch_out in enumerate(batch_outs): - outs[i] += batch_out * len(batch_ids) - else: - if batch_index == 0: - outs.append(0.) - outs[0] += batch_outs * len(batch_ids) - - if verbose == 1: - progbar.update(batch_end) - for i in range(len(outs)): - outs[i] /= num_samples - if len(outs) == 1: - return outs[0] - return outs + inputs, + targets, + batch_size=batch_size, + sample_weights=sample_weights, + verbose=verbose) def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None): - """Abstract method to loop over some data in batches. + """Predict function for eager execution. Arguments: - model: + model: Instance of `Model`. inputs: List of input arrays. batch_size: integer batch size. verbose: verbosity mode. @@ -641,49 +1104,8 @@ def predict_loop(model, inputs, (if the model has multiple outputs). 
""" with backend.learning_phase_scope(0): - num_samples = training_utils.check_num_samples( - inputs, batch_size, steps, 'steps') - if verbose == 1: - if steps is not None: - progbar = generic_utils.Progbar(target=steps) - else: - progbar = generic_utils.Progbar(target=num_samples) - - outs = [] - batches = generic_utils.make_batches(num_samples, batch_size) - index_array = np.arange(num_samples) - for batch_index, (batch_start, batch_end) in enumerate(batches): - batch_ids = index_array[batch_start:batch_end] - inputs_batch = slice_arrays(inputs, batch_ids) - - inputs_batch = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - for val in inputs_batch] - - if len(inputs_batch) == 1: - if model._expects_training_arg: - batch_outs = model.call(inputs_batch[0], training=False) - else: - batch_outs = model.call(inputs_batch[0]) - else: - if model._expects_training_arg: - batch_outs = model.call(inputs_batch, training=False) - else: - batch_outs = model.call(inputs_batch) - - if not isinstance(batch_outs, list): - batch_outs = [batch_outs] - if batch_index == 0: - # Pre-allocate the results arrays. - for batch_out in batch_outs: - dims = batch_out.shape[1:].dims - dims_list = [d.value for d in dims] - shape = (num_samples,) + tuple(dims_list) - outs.append(np.zeros(shape, dtype=batch_out.dtype.as_numpy_dtype)) - for i, batch_out in enumerate(batch_outs): - outs[i][batch_start:batch_end] = batch_out - if verbose == 1: - progbar.update(batch_end) - if len(outs) == 1: - return outs[0] - return outs + if steps is not None: + return iterator_predict_loop(model, inputs, steps, verbose=verbose) + else: + return batch_predict_loop( + model, inputs, batch_size=batch_size, verbose=verbose) diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py index 58011a1412..cc2386a5bd 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py @@ -24,6 +24,7 @@ import unittest import numpy as np from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras @@ -1340,16 +1341,12 @@ class TestTrainingWithDataTensors(test.TestCase): output_a_np) # test fit - out = model.fit(None, - output_a_np, epochs=1, batch_size=10) - out = model.fit(None, - output_a_np, epochs=1, batch_size=10) + _ = model.fit(None, output_a_np, epochs=1, steps_per_epoch=3) + _ = model.fit(None, output_a_np, epochs=1, steps_per_epoch=3) # test evaluate - out = model.evaluate(None, - output_a_np, batch_size=10) - out = model.evaluate(None, - output_a_np, batch_size=10) + _ = model.evaluate(None, output_a_np, steps=3) + _ = model.evaluate(None, output_a_np, steps=3) # test predict out = model.predict(None, steps=3) @@ -1383,16 +1380,12 @@ class TestTrainingWithDataTensors(test.TestCase): output_a_np) # test fit - out = model.fit(None, - output_a_np, epochs=1, batch_size=10) - out = model.fit(None, - output_a_np, epochs=1, batch_size=10) + _ = model.fit(None, output_a_np, epochs=1, steps_per_epoch=10) + _ = model.fit(None, output_a_np, epochs=1, steps_per_epoch=10) # test evaluate - out = model.evaluate(None, - output_a_np, batch_size=10) - out = model.evaluate(None, - output_a_np, batch_size=10) + _ = model.evaluate(None, output_a_np, steps=10) + _ = model.evaluate(None, output_a_np, steps=10) 
     # test predict
     out = model.predict(None, steps=3)
@@ -1715,40 +1708,56 @@ class TestTrainingWithDataTensors(test.TestCase):

 class TestTrainingWithDatasetIterators(test.TestCase):

+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_training_and_eval_methods_on_iterators_single_io(self):
     with self.test_session():
       x = keras.layers.Input(shape=(3,), name='input')
       y = keras.layers.Dense(4, name='dense')(x)
       model = keras.Model(x, y)

-      optimizer = 'rmsprop'
+      optimizer = RMSPropOptimizer(learning_rate=0.001)
       loss = 'mse'
       metrics = ['mae']
       model.compile(optimizer, loss, metrics=metrics)

-      inputs = np.zeros((10, 3))
-      targets = np.zeros((10, 4))
+      inputs = np.zeros((10, 3), dtype=np.float32)
+      targets = np.zeros((10, 4), dtype=np.float32)
       dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
       dataset = dataset.repeat(100)
       dataset = dataset.batch(10)
       iterator = dataset.make_one_shot_iterator()

-      model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=0)
-      model.evaluate(iterator, steps=2, verbose=0)
+      model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
+      model.evaluate(iterator, steps=2, verbose=1)
       model.predict(iterator, steps=2)
       model.train_on_batch(iterator)
       model.test_on_batch(iterator)
+      model.predict_on_batch(iterator)
+
       # Test with validation data
       model.fit(iterator,
                 epochs=1, steps_per_epoch=2, verbose=0,
                 validation_data=iterator, validation_steps=2)
       # Test with validation split
-      with self.assertRaisesRegexp(ValueError,
-                                   'you cannot use `validation_split`'):
+      with self.assertRaisesRegexp(
+          ValueError, '`validation_split` argument is not supported '
+          'when input `x` is a dataset iterator'):
         model.fit(iterator,
                   epochs=1, steps_per_epoch=2, verbose=0,
                   validation_split=0.5, validation_steps=2)
+
+      # Test with sample weight.
+      sample_weight = np.random.random((10,))
+      with self.assertRaisesRegexp(
+          ValueError, '`sample_weight` argument is not supported '
+          'when input `x` is a dataset iterator'):
+        model.fit(
+            iterator,
+            epochs=1,
+            steps_per_epoch=2,
+            verbose=0,
+            sample_weight=sample_weight)
+
       # Test invalid usage
       with self.assertRaisesRegexp(ValueError,
                                    'Instead, pass an `Iterator`'):
@@ -1759,19 +1768,54 @@ class TestTrainingWithDatasetIterators(test.TestCase):
         model.fit(iterator, iterator,
                   epochs=1, steps_per_epoch=2, verbose=0)

+      with self.assertRaisesRegexp(
+          ValueError, 'you should specify the `steps_per_epoch` argument'):
+        model.fit(iterator, epochs=1, verbose=0)
+      with self.assertRaisesRegexp(ValueError,
+                                   'you should specify the `steps` argument'):
+        model.evaluate(iterator, verbose=0)
+      with self.assertRaisesRegexp(ValueError,
+                                   'you should specify the `steps` argument'):
+        model.predict(iterator, verbose=0)
+
+  def test_get_next_op_created_once(self):
+    with self.test_session():
+      x = keras.layers.Input(shape=(3,), name='input')
+      y = keras.layers.Dense(4, name='dense')(x)
+      model = keras.Model(x, y)
+
+      optimizer = RMSPropOptimizer(learning_rate=0.001)
+      loss = 'mse'
+      metrics = ['mae']
+      model.compile(optimizer, loss, metrics=metrics)
+
+      inputs = np.zeros((10, 3), dtype=np.float32)
+      targets = np.zeros((10, 4), dtype=np.float32)
+      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+      dataset = dataset.repeat(100)
+      dataset = dataset.batch(10)
+      iterator = dataset.make_one_shot_iterator()
+
+      model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
+      # Finalize graph to make sure we are not appending another iterator
+      # get_next op in the graph.
+      ops.get_default_graph().finalize()
+      model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
+
+  @tf_test_util.run_in_graph_and_eager_modes()
   def test_iterators_running_out_of_data(self):
     with self.test_session():
       x = keras.layers.Input(shape=(3,), name='input')
       y = keras.layers.Dense(4, name='dense')(x)
       model = keras.Model(x, y)

-      optimizer = 'rmsprop'
+      optimizer = RMSPropOptimizer(learning_rate=0.001)
       loss = 'mse'
       metrics = ['mae']
       model.compile(optimizer, loss, metrics=metrics)

-      inputs = np.zeros((10, 3))
-      targets = np.zeros((10, 4))
+      inputs = np.zeros((10, 3), dtype=np.float32)
+      targets = np.zeros((10, 4), dtype=np.float32)
       dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
       dataset = dataset.repeat(2)
       dataset = dataset.batch(10)
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_utils.py b/tensorflow/python/keras/_impl/keras/engine/training_utils.py
index 662938f421..04d80c891f 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_utils.py
@@ -22,6 +22,7 @@ import copy

 import numpy as np

+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.eager import context
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras._impl.keras import backend as K
@@ -65,14 +66,7 @@ def check_num_samples(ins,
   if steps is not None and batch_size is not None:
     raise ValueError(
         'If ' + steps_name + ' is set, the `batch_size` must be None.')
-
-  if not ins or has_symbolic_tensors(ins):
-    if steps is None:
-      raise ValueError('If your data is in the form of symbolic tensors, '
-                       'you should specify the `' + steps_name + '` argument '
-                       '(instead of the `batch_size` argument, '
-                       'because symbolic tensors are expected to produce '
-                       'batches of input data).')
+  if check_steps_argument(ins, steps, steps_name):
     return None
   if hasattr(ins[0], 'shape'):
     return int(ins[0].shape[0])
@@ -551,8 +545,11 @@ def standardize_weights(y,


 def has_symbolic_tensors(ls):
-  return (any(tensor_util.is_tensor(v) for v in ls)
-          and not context.executing_eagerly())
+  if context.executing_eagerly():
+    return False
+  if isinstance(ls, (list, tuple)):
+    return any(tensor_util.is_tensor(v) for v in ls)
+  return tensor_util.is_tensor(ls)


 def populate_metric_names(model):
@@ -614,3 +611,77 @@ def add_metric_name(model, metric_name, index):
     metric_name = '%s_%d' % (base_metric_name, j)
     j += 1
   model.metrics_names.append(metric_name)


+def validate_iterator_input(x, y, sample_weight, validation_split=None):
+  """Validates user input arguments when a dataset iterator is passed.
+
+  Arguments:
+    x: Input data. A `tf.data` dataset iterator.
+    y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s).
+      Expected to be `None` when `x` is a dataset iterator.
+    sample_weight: An optional sample-weight array passed by the user to
+      weight the importance of each sample in `x`. Expected to be `None` when
+      `x` is a dataset iterator.
+    validation_split: Float between 0 and 1. Fraction of the training data to
+      be used as validation data. Expected to be `None` when `x` is a dataset
+      iterator.
+
+  Raises:
+    ValueError: if `y`, `sample_weight`, or `validation_split` is provided
+      by the user.
+  """
+  if y is not None:
+    raise ValueError('You passed a dataset iterator (%s) as input `x` to '
+                     'your model. In that case, you should not specify '
+                     'a target (`y`) argument, since the dataset iterator '
+                     'generates both input data and target data. '
+                     'Received: %s' % (x, y))
+  if sample_weight is not None:
+    raise ValueError('`sample_weight` argument is not supported when input'
+                     ' `x` is a dataset iterator. '
+                     'Received: x=%s, sample_weight=%s' % (x, sample_weight))
+  if validation_split is not None and validation_split != 0.0:
+    raise ValueError(
+        '`validation_split` argument is not supported when '
+        'input `x` is a dataset iterator. '
+        'Received: x=%s, validation_split=%f' % (x, validation_split))
+
+
+def check_steps_argument(input_data, steps, steps_name):
+  """Validates `steps` argument based on input data's type.
+
+  The cases when `steps` value must be provided are when
+    1. input data passed is an iterator.
+    2. model was built on top of symbolic tensors, input data is not
+       required and is `None`.
+    3. input data passed is a symbolic tensor.
+
+  Arguments:
+    input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or
+      tf.data.Dataset iterator or `None`.
+    steps: Integer or `None`. Total number of steps (batches of samples) to
+      execute.
+    steps_name: The public API's parameter name for `steps`.
+
+  Returns:
+    boolean, True if `steps` argument is required, else False.
+
+  Raises:
+    ValueError: if `steps` argument is required for given input data type
+      but not provided.
+  """
+
+  is_x_iterator = (
+      isinstance(input_data, iterator_ops.Iterator) or
+      isinstance(input_data, iterator_ops.EagerIterator))
+
+  if (input_data is None or is_x_iterator or has_symbolic_tensors(input_data) or
+      (isinstance(input_data, list) and not input_data)):
+    if steps is None:
+      input_type_str = 'iterators' if is_x_iterator else 'data tensors'
+      raise ValueError('When using {input_type} as input to a model, you should'
+                       ' specify the `{steps_name}` argument.'.format(
+                           input_type=input_type_str, steps_name=steps_name))
+    return True
+  return False
diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
index 295ad47f6b..1e88dc09fb 100644
--- a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
@@ -23,12 +23,15 @@ import os
 import numpy as np
 import six

+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
+from tensorflow.python.training import checkpointable
 from tensorflow.python.training.rmsprop import RMSPropOptimizer

 try:
@@ -248,6 +251,26 @@ class ModelSubclassingTest(test.TestCase):
       model.fit([x1, x2], [y1, y2], epochs=2, steps_per_epoch=10, verbose=0)
       _ = model.evaluate(steps=10, verbose=0)

+  @test_util.run_in_graph_and_eager_modes()
+  def test_single_io_workflow_with_dataset_iterators(self):
+    num_classes = 2
+    num_samples = 10
+    input_dim = 50
+
+    with self.test_session():
+      model = SimpleTestModel(num_classes=num_classes, use_dp=True, use_bn=True)
+      model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+      x = np.ones((num_samples, input_dim))
+      y = np.zeros((num_samples, num_classes))
+      dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+      dataset = dataset.repeat(100)
+      dataset = dataset.batch(10)
+      iterator = dataset.make_one_shot_iterator()
+
+      model.fit(iterator, epochs=2, steps_per_epoch=10, verbose=0)
+      _ = model.evaluate(iterator, steps=10, verbose=0)
+
   def test_multi_io_workflow_with_numpy_arrays_and_custom_placeholders(self):

     num_classes = (2, 3)
@@ -583,6 +606,22 @@ class ModelSubclassingTest(test.TestCase):
     loss = model.train_on_batch(x, y)
     self.assertGreater(loss, 0.1)

+  def test_no_dependency(self):
+    class Foo(keras.Model):
+
+      def __init__(self):
+        super(Foo, self).__init__()
+        self.isdep = keras.layers.Dense(1)
+        self.notdep = checkpointable.NoDependency(keras.layers.Dense(2))
+        self.notdep_var = checkpointable.NoDependency(
+            resource_variable_ops.ResourceVariable(1., name='notdep_var'))
+
+    m = Foo()
+    self.assertEqual([m.isdep, m.notdep], m.layers)
+    self.assertEqual(1, len(m._checkpoint_dependencies))
+    self.assertIs(m.isdep, m._checkpoint_dependencies[0].ref)
+    self.assertEqual('notdep_var:0', m.notdep_var.name)
+

 class CustomCallModel(keras.Model):
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 77e6f5f1a0..843759fed0 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1847,6 +1847,23 @@ class ControlFlowTest(test.TestCase):
       r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
       self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))

+  def testGradInWhileWrtInitialLoopVal(self):
+    with self.test_session():
+      x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
+      y = x + 1
+
+      def body(i, v):
+        z = v * 2
+        return i + 1, gradients_impl.gradients(z, x)[0]
+
+      with self.assertRaisesRegexp(
+          ValueError,
+          "Cannot compute gradient inside while loop with respect to op 'x'. "
+          "We do not support taking the gradient wrt or through the initial "
+          "value of a loop variable. Gradients can be computed through "
+          "loop invariants or wrt the input parameters to the loop body."):
+        control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
+
   def testWhileGradInWhile(self):
     with self.test_session():
       n = ops.convert_to_tensor(1.0, name="n")
diff --git a/tensorflow/python/kernel_tests/conv2d_transpose_test.py b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
index b692d3da60..27804be65c 100644
--- a/tensorflow/python/kernel_tests/conv2d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
@@ -23,6 +23,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin

 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import nn_ops
@@ -292,6 +293,7 @@ class Conv2DTransposeTest(test.TestCase):

     self.assertAllClose(cache_values, value)

+  @test_util.enable_c_shapes
   def testConv2DTransposeShapeInference(self):
     # Test case for 8972
     initializer = random_ops.truncated_normal(
@@ -301,7 +303,8 @@ class Conv2DTransposeTest(test.TestCase):
     f_shape = array_ops.stack([array_ops.shape(x)[0], 10, 5, 5])
     output = nn_ops.conv2d_transpose(
         x, f, f_shape, strides=[1, 1, 1, 1], padding="SAME")
-    self.assertEqual(output.get_shape().as_list(), [None, 10, 5, 5])
+    self.assertEqual(output.get_shape().as_list(), [3, 10, 5, 5])
+

 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index 5ec80b95ee..dc465c867f 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -147,6 +147,32 @@ class AssertCloseTest(test.TestCase):
         array_ops.identity(w).eval(feed_dict=feed_dict)


+class MaybeGetStaticTest(test.TestCase):
+
+  def testGetStaticInt(self):
+    x = 2
+    self.assertEqual(x, du.maybe_get_static_value(x))
+    self.assertAllClose(
+        np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))
+
+  def testGetStaticNumpyArray(self):
+    x = np.array(2, dtype=np.int32)
+    self.assertEqual(x, du.maybe_get_static_value(x))
+    self.assertAllClose(
+        np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))
+
+  def testGetStaticConstant(self):
+    x = constant_op.constant(2, dtype=dtypes.int32)
+    self.assertEqual(np.array(2, dtype=np.int32), du.maybe_get_static_value(x))
+    self.assertAllClose(
+        np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))
+
+  def testGetStaticPlaceholder(self):
+    x = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+    self.assertEqual(None, du.maybe_get_static_value(x))
+    self.assertEqual(None, du.maybe_get_static_value(x, dtype=np.float64))
+
+
 @test_util.with_c_api
 class GetLogitsAndProbsTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index 052f11f92e..91be80322c 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -85,7 +85,10 @@ cuda_py_test(
         "//tensorflow/python:platform_test",
     ],
     shard_count = 5,
-    tags = ["noasan"],  # times out b/63678675
+    tags = [
+        "noasan",  # times out, b/63678675
+        "optonly",  # times out, b/79171797
+    ],
 )

 cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 098f9724a2..49855200c2 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -43,6 +43,7 @@ def scalar_shape():
   return ops.convert_to_tensor([], dtype=dtypes.int32)


+@test_util.with_c_shapes
 class ListOpsTest(test_util.TensorFlowTestCase):

   @test_util.run_in_graph_and_eager_modes()
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 984192258c..3daf07ea63 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -400,6 +400,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
             resource_variable_ops.var_is_initialized_op(abc.handle)),
         True)

+  def testScatterBool(self):
+    with context.eager_mode():
+      ref = resource_variable_ops.ResourceVariable(
+          [False, True, False], trainable=False)
+      indices = math_ops.range(3)
+      updates = constant_op.constant([True, True, True])
+      state_ops.scatter_update(ref, indices, updates)
+      self.assertAllEqual(ref.read_value(), [True, True, True])
+
   @test_util.run_in_graph_and_eager_modes()
   def testConstraintArg(self):
     constraint = lambda x: x
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 918bbd38ed..c0b36f143d 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -438,7 +438,6 @@ class TensorArrayTest(test.TestCase):
                                  "Tried to read from index 3 but array size is: 3"):
       self.evaluate(ta.read(3))

-  @test_util.run_in_graph_and_eager_modes()
   def testTensorArrayWriteMultipleFails(self):
     with self.test_session(use_gpu=True):
       ta = tensor_array_ops.TensorArray(
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index 306055d202..cabc1e724c 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -1169,19 +1169,35 @@ def _assert_same_base_type(items, expected_type=None):
   Raises:
     ValueError: If any types do not match.
   """
-  original_item_str = None
+  original_expected_type = expected_type
+  mismatch = False
   for item in items:
     if item is not None:
       item_type = item.dtype.base_dtype
       if not expected_type:
         expected_type = item_type
-        original_item_str = item.name if hasattr(item, 'name') else str(item)
       elif expected_type != item_type:
-        raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
-            item.name if hasattr(item, 'name') else str(item),
-            item_type, expected_type,
-            (' as %s' % original_item_str) if original_item_str else ''))
-  return expected_type
+        mismatch = True
+        break
+  if mismatch:
+    # Loop back through and build up an informative error message (this is very
+    # slow, so we don't do it unless we found an error above).
+    expected_type = original_expected_type
+    original_item_str = None
+    for item in items:
+      if item is not None:
+        item_type = item.dtype.base_dtype
+        if not expected_type:
+          expected_type = item_type
+          original_item_str = item.name if hasattr(item, 'name') else str(item)
+        elif expected_type != item_type:
+          raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
+              item.name if hasattr(item, 'name') else str(item),
+              item_type, expected_type,
+              (' as %s' % original_item_str) if original_item_str else ''))
+    return expected_type  # Should be unreachable
+  else:
+    return expected_type


 @tf_export('assert_same_float_dtype')
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 07d4ff7b02..5f60dab6ac 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -43,6 +43,7 @@ from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_control_flow_ops
 from tensorflow.python.ops import gen_data_flow_ops
 from tensorflow.python.ops import gen_logging_ops
+from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import tensor_array_ops
 # go/tf-wildcard-import
@@ -1433,6 +1434,8 @@ def ZerosLikeOutsideLoop(op, index):
   """Create zeros_like for the specified output of an op."""
   val = op.outputs[index]
   if not util.IsSwitch(op):
+    if val.dtype == dtypes.resource:
+      return array_ops.zeros(gen_resource_variable_ops.variable_shape(val))
     return array_ops.zeros_like(val, optimize=False)
   else:
     op_ctxt = op._get_control_flow_context()
@@ -1441,6 +1444,10 @@ def ZerosLikeOutsideLoop(op, index):
       pred = op_ctxt.pred
       branch = op_ctxt.branch
       switch_val = switch(op.inputs[0], pred)[1 - branch]
+      if val.dtype == dtypes.resource:
+        with ops.control_dependencies([switch_val]):
+          return array_ops.zeros(
+              gen_resource_variable_ops.variable_shape(switch_val))
       zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
       # Ensure ops created within array_ops.zeros are dominated by switch in
       # cond context.
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index 36eee5ce78..caceadf53a 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import util as distribution_util


 __all__ = [
@@ -527,8 +528,6 @@ class Bijector(object):
       ValueError: If a member of `graph_parents` is not a `Tensor`.
""" self._graph_parents = graph_parents or [] - forward_min_event_ndims = get_static_value(forward_min_event_ndims) - inverse_min_event_ndims = get_static_value(inverse_min_event_ndims) if forward_min_event_ndims is None and inverse_min_event_ndims is None: raise ValueError("Must specify at least one of `forward_min_event_ndims` " @@ -538,12 +537,23 @@ class Bijector(object): elif forward_min_event_ndims is None: forward_min_event_ndims = inverse_min_event_ndims + if not isinstance(forward_min_event_ndims, int): + raise TypeError("Expected forward_min_event_ndims to be of " + "type int, got {}".format( + type(forward_min_event_ndims).__name__)) + + if not isinstance(inverse_min_event_ndims, int): + raise TypeError("Expected inverse_min_event_ndims to be of " + "type int, got {}".format( + type(inverse_min_event_ndims).__name__)) + if forward_min_event_ndims < 0: raise ValueError("forward_min_event_ndims must be a non-negative " "integer.") if inverse_min_event_ndims < 0: raise ValueError("inverse_min_event_ndims must be a non-negative " "integer.") + self._forward_min_event_ndims = forward_min_event_ndims self._inverse_min_event_ndims = inverse_min_event_ndims self._is_constant_jacobian = is_constant_jacobian @@ -994,7 +1004,6 @@ class Bijector(object): def _reduce_jacobian_det_over_event( self, y, ildj, min_event_ndims, event_ndims): """Reduce jacobian over event_ndims - min_event_ndims.""" - assert_static(min_event_ndims) if not self.is_constant_jacobian: return math_ops.reduce_sum( @@ -1012,7 +1021,7 @@ class Bijector(object): axis=self._get_event_reduce_dims(min_event_ndims, event_ndims)) # The multiplication by ones can change the inferred static shape so we try # to recover as much as possible. - event_ndims_ = get_static_value(event_ndims) + event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims) if (event_ndims_ is not None and y.shape.ndims is not None and ildj.shape.ndims is not None): @@ -1027,8 +1036,7 @@ class Bijector(object): def _get_event_reduce_dims(self, min_event_ndims, event_ndims): """Compute the reduction dimensions given event_ndims.""" - assert_static(min_event_ndims) - event_ndims_ = get_static_value(event_ndims, np.int32) + event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims) if event_ndims_ is not None: return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)] @@ -1038,8 +1046,7 @@ class Bijector(object): def _check_valid_event_ndims(self, min_event_ndims, event_ndims): """Check whether event_ndims is atleast min_event_ndims.""" - assert_static(min_event_ndims) - event_ndims_ = get_static_value(event_ndims, np.int32) + event_ndims_ = self._maybe_get_event_ndims_statically(event_ndims) assertions = [] if event_ndims_ is not None: if min_event_ndims > event_ndims_: @@ -1051,21 +1058,15 @@ class Bijector(object): check_ops.assert_greater_equal(event_ndims, min_event_ndims)] return assertions + def _maybe_get_event_ndims_statically(self, event_ndims): + """Helper which returns tries to return an integer static value.""" + event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) -def get_static_value(x, dtype=None): - """Helper which returns static value; casting when dtype is preferred.""" - if x is None: - return x - try: - x_ = tensor_util.constant_value(x) - except TypeError: - x_ = x - if x_ is None or dtype is None: - return x_ - return np.array(x_, dtype) - + if isinstance(event_ndims_, np.ndarray): + if (event_ndims_.dtype not in (np.int32, np.int64) or + len(event_ndims_.shape)): + raise 
ValueError("Expected a scalar integer, got {}".format( + event_ndims_)) + event_ndims_ = event_ndims_.tolist() -def assert_static(x): - """Helper which asserts that input arg is known statically.""" - if x is None or type(x) != type(get_static_value(x)): # pylint: disable=unidiomatic-typecheck - raise TypeError("Input must be known statically.") + return event_ndims_ diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index 2e067eab45..3afa85fda0 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -162,6 +162,30 @@ def same_dynamic_shape(a, b): lambda: constant_op.constant(False)) +def maybe_get_static_value(x, dtype=None): + """Helper which tries to return a static value. + + Given `x`, extract it's value statically, optionally casting to a specific + dtype. If this is not possible, None is returned. + + Args: + x: `Tensor` for which to extract a value statically. + dtype: Optional dtype to cast to. + + Returns: + Statically inferred value if possible, otherwise None. + """ + if x is None: + return x + try: + x_ = tensor_util.constant_value(x) + except TypeError: + x_ = x + if x_ is None or dtype is None: + return x_ + return np.array(x_, dtype) + + def get_logits_and_probs(logits=None, probs=None, multidimensional=False, diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 1448151fef..069b5a4308 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -297,7 +297,8 @@ def _DefaultGradYs(grad_ys, def _IsTrainable(tensor): dtype = dtypes.as_dtype(tensor.dtype) return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64, - dtypes.complex64, dtypes.complex128) + dtypes.complex64, dtypes.complex128, + dtypes.resource) def _IsBackpropagatable(tensor): @@ -417,6 +418,30 @@ def _MaybeCompile(scope, op, func, grad_fn): return grad_fn() +def _RaiseNoGradWrtInitialLoopValError(op, from_ops): + """Raises an error if we backprop through a loop var.""" + # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error + # message. + target_op = None + queue = collections.deque([op]) + visited = set() + while queue: + curr_op = queue.popleft() + if curr_op in visited: continue + visited.add(curr_op) + if curr_op in from_ops: + target_op = curr_op + break + queue.extend(t.op for t in curr_op.inputs) + assert target_op + raise ValueError( + "Cannot compute gradient inside while loop with respect to op '%s'. " + "We do not support taking the gradient wrt or through the initial value " + "of a loop variable. Gradients can be computed through loop invariants " + "or wrt the input parameters to the loop body." + % target_op.name) + + @tf_export("gradients") def gradients(ys, xs, @@ -629,6 +654,21 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, (op.name, op.type)) if loop_state: loop_state.EnterGradWhileContext(op, before=False) + + # NOTE(skyewm): We don't support computing gradients wrt a loop variable + # unless it's within the context of a single iteration (i.e. the + # gradient is wrt to the loop parameter in the body function, not wrt or + # through the initial value). This means if we're in a while loop + # context, we should never see a switch node from this context. 
+        # pylint: disable=protected-access
+        if (control_flow_util.IsSwitch(op) and
+            op._control_flow_context is not None and
+            op._control_flow_context.IsWhileContext() and
+            op._control_flow_context ==
+            ops.get_default_graph()._get_control_flow_context()):
+          _RaiseNoGradWrtInitialLoopValError(op, from_ops)
+        # pylint: enable=protected-access
+
         if (grad_fn or is_func_call) and has_out_grads:
           # NOTE: If _AggregatedGrads didn't compute a value for the i'th
           # output, it means that the cost does not depend on output[i],
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 5e8b8822ef..e729950201 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -944,6 +944,21 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
       # Smoke test to ensure numpy inputs are accepted
       F(x)

+  def testRVGradientsDynamicCond(self):
+    with self.test_session():
+      alpha = resource_variable_ops.ResourceVariable(
+          np.random.random((1,)),
+          dtype="float32")
+
+      conditional = array_ops.placeholder_with_default(True, shape=())
+      output = control_flow_ops.cond(
+          conditional, lambda: alpha * 2, lambda: alpha * 3)
+
+      g, = gradients_impl.gradients(output, alpha)
+      variables.global_variables_initializer().run()
+      self.assertAllEqual(g.eval(), [2.0])
+      self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
+

 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index d2f45ce37b..cc92da4fd7 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function

 import contextlib
+import weakref

 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -395,69 +396,8 @@ class _GraphTensorArray(object):
   # pylint: enable=protected-access


-# pylint: disable=protected-access
-def _eager_write_no_copy(ta, index, value):
-  """Writes value into an _EagerTensorArray without creating a new TensorArray.
-
-  Args:
-    ta: _EagerTensorArray into which to write value.
-    index: 0-D.  int32 scalar with the index to write to.
-    value: N-D.  Tensor of type `dtype`.  The Tensor to write to this index.
-
-  Raises:
-    errors_impl.AlreadyExistsError: attempting to overwrite an entry.
-    errors_impl.InvalidArgumentError: value dtype does not match `ta`'s dtype.
-    errors_impl.OutOfRangeError: `index` is out of bounds.
-    ValueError: shape of `value` is not consistent with inferred shape.
-  """
-
-  if isinstance(index, ops.EagerTensor):
-    index = index.numpy()
-
-  if index < 0:
-    raise errors_impl.OutOfRangeError(
-        None, None,
-        "Writing to negative indices (index %d) is not allowed." % index)
-
-  tensor_array = ta._tensor_array
-  size = len(tensor_array)
-  if index >= size:
-    if not ta._dynamic_size:
-      raise errors_impl.OutOfRangeError(
-          None, None,
-          "Tried to write to index %d but array is not resizeable and size "
-          "is: %d" % (index, size))
-    tensor_array.extend([None for _ in range(index - size + 1)])
-
-  if not isinstance(value, ops.EagerTensor):
-    value = constant_op.constant(value)
-
-  if ta._infer_shape:
-    if ta._element_shape is None:
-      ta._element_shape = value.shape
-    elif ta._element_shape != value.shape:
-      raise ValueError("Incompatible shape for value (%s), expected (%s)" %
-                       (value.shape.as_list(), ta._element_shape.as_list()))
-
-  if ta._dtype != value.dtype:
-    raise errors_impl.InvalidArgumentError(
-        None, None,
-        "TensorArray dtype is %s but Op is trying to write dtype %s" %
-        (ta._dtype.name, value.dtype.name))
-
-  if ta._tensor_array[index] is not None:
-    raise errors_impl.AlreadyExistsError(
-        None, None,
-        "Could not write to TensorArray index %d because it has already been "
-        "written to." % index)
-
-  tensor_array[index] = value
-
-# pylint: enable=protected-access
-
-
 class _EagerTensorArray(object):
-  """Eager-mode implementation of TensorArray.
+  """Eager-compatible implementation of TensorArray.
   """

   def __init__(self,
@@ -472,7 +412,7 @@ class _EagerTensorArray(object):
                element_shape=None,
                colocate_with_first_write_call=True,
                name=None):
-    """Constructs an Eager mode TensorArray.
+    """Constructs a TensorArray compatible with eager execution.

     Args:
       dtype: (required) data type of the TensorArray.
@@ -495,16 +435,19 @@ class _EagerTensorArray(object):
       ValueError: handle or flow are supplied, or if size is not supplied.
     """

-    del (flow, tensor_array_name, name)  # not meaningful in Eager
+    del (flow, tensor_array_name, name)  # Unused.

     if handle is not None:
-      raise ValueError("TensorArray handles are not supported in Eager mode.")
+      raise ValueError("TensorArray handles are not supported when eager "
+                       "execution is enabled.")
     if size is None:
-      raise ValueError("Size must be declared for TensorArrays in Eager mode.")
+      raise ValueError("Size must be declared for TensorArrays when eager "
+                       "execution is enabled.")

-    # These attributes are not meaningful in Eager, but some library functions
-    # (e.g., those in control_flow_ops.py) access them to create new tensor
-    # arrays; as such, we define them for the sake of compatibility.
+    # These attributes are not meaningful when eager is enabled, but some
+    # library functions (e.g., those in control_flow_ops.py) access them to
+    # create new tensor arrays; as such, we define them for the sake of
+    # compatibility.
     self._handle = None
     # we assign a dummy value to _flow in case other code assumes it to be
     # a Tensor
@@ -525,7 +468,7 @@ class _EagerTensorArray(object):

   @property
   def flow(self):
-    """Flows are not meaningful in Eager; this exists for compatibility."""
+    """For compatibility; flows are not meaningful when eager is enabled."""
     return self._flow

   @property
@@ -534,42 +477,22 @@ class _EagerTensorArray(object):

   @property
   def handle(self):
-    """Handles are not meaningful in Eager; this exists for compatibility."""
+    """For compatibility; handles are not meaningful when eager is enabled."""
     return self._handle

-  def _identity_without_array(self):
-    """Returns a new TensorArray with the same properties as this Eager one.
-
-    NB: Does not set the underlying _tensor_array attribute.
- """ - ta = TensorArray( - dtype=self._dtype, - size=len(self._tensor_array), - dynamic_size=self._dynamic_size, - clear_after_read=self._clear_after_read, - handle=self._handle, - flow=self._flow, - infer_shape=self._infer_shape, - element_shape=self._element_shape, - colocate_with_first_write_call=self._colocate_with_first_write_call) - ta._implementation._previously_read_indices = self._previously_read_indices # pylint: disable=protected-access - return ta - def identity(self): """See TensorArray.""" - ta = self._identity_without_array() - ta._implementation._tensor_array = [t for t in self._tensor_array] # pylint: disable=protected-access - return ta + return self.parent() def grad(self, source, flow=None, name=None): raise NotImplementedError( - "TensorArray.grad is not supported in Eager mode; Eager's gradient " - "implementation does not use/need this function to compute gradients " - "of operations that use TensorArrays.") + "TensorArray.grad is not supported when executing eagerly; eager's " + "gradient implementation does not use/need this function to compute " + "gradients of operations that use TensorArrays.") def read(self, index, name=None): """See TensorArray.""" - del name # not meaningful in Eager mode + del name # not meaningful when executing eagerly. if isinstance(index, ops.EagerTensor): index = index.numpy() @@ -600,12 +523,58 @@ class _EagerTensorArray(object): self._previously_read_indices.append(index) return tensor + def _write(self, index, value): + """Writes `value` into index named by `index`. + + Args: + index: 0-D. int32 scalar with the index to write to. + value: N-D. Tensor of type `dtype`. The `Tensor` to write to `index`. + + Raises: + errors_impl.InvalidArgumentError: `value` dtype does not match dtype. + errors_impl.OutOfRangeError: `index` is out of bounds. + ValueError: shape of `value` is not consistent with inferred shape. + """ + + if isinstance(index, ops.EagerTensor): + index = index.numpy() + + if index < 0: + raise errors_impl.OutOfRangeError( + None, None, + "Writing to negative indices (index %d) is not allowed." % index) + + size = len(self._tensor_array) + if index >= size: + if not self._dynamic_size: + raise errors_impl.OutOfRangeError( + None, None, + "Tried to write to index %d but array is not resizeable and size " + "is: %d" % (index, size)) + self._tensor_array.extend([None for _ in range(index - size + 1)]) + + if not isinstance(value, ops.EagerTensor): + value = constant_op.constant(value) + + if self._infer_shape: + if self._element_shape is None: + self._element_shape = value.shape + elif self._element_shape != value.shape: + raise ValueError("Incompatible shape for value (%s), expected (%s)" % + (value.shape.as_list(), self._element_shape.as_list())) + + if self._dtype != value.dtype: + raise errors_impl.InvalidArgumentError( + None, None, + "TensorArray dtype is %s but Op is trying to write dtype %s" % + (self._dtype.name, value.dtype.name)) + self._tensor_array[index] = value + def write(self, index, value, name=None): """See TensorArray.""" - del name # not meaningful in Eager mode - ta = self.identity() - _eager_write_no_copy(ta._implementation, index, value) # pylint: disable=protected-access - return ta + del name # not meaningful when executing eagerly. 
+    self._write(index, value)
+    return self.parent()

   def _maybe_zero(self, ix):
     val = self._tensor_array[ix]
@@ -623,7 +592,7 @@ class _EagerTensorArray(object):

   def gather(self, indices, name=None):
     """See TensorArray."""
-    del name  # not meaningful in Eager mode
+    del name  # not meaningful when executing eagerly.
     return array_ops.stack([self._maybe_zero(i) for i in indices.numpy()])

   def concat(self, name=None):
@@ -651,17 +620,15 @@ class _EagerTensorArray(object):
       raise ValueError(
          "Cannot unstack %d tensors into a TensorArray of static size %d" %
          (len(tensors), len(self._tensor_array)))
-    ta = self._identity_without_array()
-    ta._implementation._tensor_array = tensors  # pylint: disable=protected-access
-    return ta
+    self._tensor_array = tensors
+    return self.parent()

   def scatter(self, indices, value, name=None):
     """See TensorArray."""
-    del name  # unused in Eager
-    ta = self.identity()
+    del name  # not meaningful when executing eagerly.
     for index, val in zip(indices.numpy(), array_ops.unstack(value)):
-      _eager_write_no_copy(ta._implementation, index, val)  # pylint: disable=protected-access
-    return ta
+      self._write(index, val)  # pylint: disable=protected-access
+    return self.parent()

   def split(self, value, lengths, name=None):
     """See TensorArray."""
@@ -690,20 +657,17 @@ class _EagerTensorArray(object):
           "dynamically resizeable" %
          (len(self._tensor_array), lengths.shape[0]))
     else:
-      ta = self._identity_without_array()
-      tensor_array = array_ops.split(value, lengths, name=name)
-      ta._implementation._tensor_array = tensor_array  # pylint: disable=protected-access
-      return ta
+      self._tensor_array = array_ops.split(value, lengths, name=name)
+      return self.parent()

   def size(self, name=None):
     """See TensorArray."""
-    del name  # not meaningful in Eager mode
+    del name  # not meaningful when executing eagerly.
     return constant_op.constant(len(self._tensor_array))

   def close(self, name=None):
-    del name  # not meaningful in Eager mode
+    del name  # not meaningful when executing eagerly.
     del self._tensor_array[:]
-    return


 # TensorArray is designed to hide an underlying implementation object
@@ -789,6 +753,8 @@ class TensorArray(object):
           colocate_with_first_write_call=colocate_with_first_write_call,
           name=name)

+    self._implementation.parent = weakref.ref(self)
+
   @property
   def flow(self):
     """The flow `Tensor` forcing ops leading to this TensorArray state."""
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 01903ae596..8f1d5a099f 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -169,6 +169,25 @@ class SavedModelBuilder(object):
         raise TypeError("main_op needs to be an Operation: %r" % main_op)
       ops.add_to_collection(constants.MAIN_OP_KEY, main_op)

+  def _add_train_op(self, train_op):
+    """Add train op to the SavedModel.
+
+    Note that this functionality is in development, and liable to be
+    moved elsewhere.
+
+    Args:
+      train_op: Op or group of ops that are used for training. These are
+        stored as a collection with key TRAIN_OP_KEY, but not executed.
+
+    Raises:
+      TypeError: if `train_op` is not of type `Tensor` or `Operation`.
+ """ + if train_op is not None: + if (not isinstance(train_op, ops.Tensor) and + not isinstance(train_op, ops.Operation)): + raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op) + ops.add_to_collection(constants.TRAIN_OP_KEY, train_op) + def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map): """Tags the meta graph def and adds it to the SavedModel. @@ -239,6 +258,20 @@ class SavedModelBuilder(object): for outputs_key in outputs: self._validate_tensor_info(outputs[outputs_key]) + def _add_collections( + self, assets_collection, legacy_init_op, main_op, train_op): + """Add asset and op collections to be saved.""" + # Save asset files and write them to disk, if any. + self._save_and_write_assets(assets_collection) + + if main_op is None: + # Add legacy init op to the SavedModel. + self._maybe_add_legacy_init_op(legacy_init_op) + else: + self._add_main_op(main_op) + + self._add_train_op(train_op) + def add_meta_graph(self, tags, signature_def_map=None, @@ -286,14 +319,8 @@ class SavedModelBuilder(object): # properly populated. self._validate_signature_def_map(signature_def_map) - # Save asset files and write them to disk, if any. - self._save_and_write_assets(assets_collection) - - if main_op is None: - # Add legacy init op to the SavedModel. - self._maybe_add_legacy_init_op(legacy_init_op) - else: - self._add_main_op(main_op) + # Add assets and ops + self._add_collections(assets_collection, legacy_init_op, main_op, None) # Initialize a saver to generate a sharded output for all saveables in the # current scope. @@ -352,6 +379,7 @@ class SavedModelBuilder(object): strip_default_attrs: Boolean. If `True`, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + """ # pylint: enable=line-too-long if self._has_saved_variables: @@ -363,8 +391,8 @@ class SavedModelBuilder(object): # properly populated. self._validate_signature_def_map(signature_def_map) - # Save asset files and write them to disk, if any. - self._save_and_write_assets(assets_collection) + # Add assets and ops + self._add_collections(assets_collection, legacy_init_op, main_op, None) # Create the variables sub-directory, if it does not exist. variables_dir = os.path.join( @@ -377,12 +405,6 @@ class SavedModelBuilder(object): compat.as_text(variables_dir), compat.as_text(constants.VARIABLES_FILENAME)) - if main_op is None: - # Add legacy init op to the SavedModel. - self._maybe_add_legacy_init_op(legacy_init_op) - else: - self._add_main_op(main_op) - # Initialize a saver to generate a sharded output for all saveables in the # current scope. saver = tf_saver.Saver( diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py index 34206c6f6d..61c6ffbd0d 100644 --- a/tensorflow/python/saved_model/constants.py +++ b/tensorflow/python/saved_model/constants.py @@ -41,6 +41,10 @@ MAIN_OP_KEY = "saved_model_main_op" tf_export("saved_model.constants.MAIN_OP_KEY").export_constant( __name__, "MAIN_OP_KEY") +# CollectionDef key for the SavedModel train op. +# Not exported while export_all_saved_models is in contrib. +TRAIN_OP_KEY = "saved_model_train_op" + # Schema version for SavedModel. 
 SAVED_MODEL_SCHEMA_VERSION = 1

 tf_export("saved_model.constants.SAVED_MODEL_SCHEMA_VERSION").export_constant(
@@ -65,3 +69,5 @@ tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant(
 VARIABLES_FILENAME = "variables"
 tf_export("saved_model.constants.VARIABLES_FILENAME").export_constant(
     __name__, "VARIABLES_FILENAME")
+
+
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 804255375e..a4d994fd43 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -734,6 +734,96 @@ class SavedModelTest(test.TestCase):
         builder.add_meta_graph_and_variables(
             sess, ["foo"], legacy_init_op=legacy_init_op)

+  def testTrainOp(self):
+    export_dir = self._get_export_dir("test_train_op")
+    builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      # Add `v1` and `v2` variables to the graph.
+      v1 = variables.Variable(1, name="v1")
+      ops.add_to_collection("v", v1)
+      v2 = variables.Variable(2, name="v2")
+      ops.add_to_collection("v", v2)
+
+      sess.run(variables.global_variables_initializer())
+      train_op = state_ops.assign_add(v1, v2)
+
+      sess.run(train_op)
+      # TODO(karmel): remove explicit call when in the public method.
+      builder._add_train_op(train_op)
+      builder.add_meta_graph_and_variables(sess, ["foo"])
+
+    # Save the SavedModel to disk.
+    builder.save()
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      loader.load(sess, ["foo"], export_dir)
+      self.assertEqual(3, ops.get_collection("v")[0].eval())
+      self.assertEqual(2, ops.get_collection("v")[1].eval())
+      self.assertIsInstance(
+          ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)
+
+  def testTrainOpGroup(self):
+    export_dir = self._get_export_dir("test_train_op_group")
+    builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      # Add `v1` and `v2` variables to the graph.
+      v1 = variables.Variable(1, name="v1")
+      ops.add_to_collection("v", v1)
+      v2 = variables.Variable(2, name="v2")
+      ops.add_to_collection("v", v2)
+
+      sess.run(variables.global_variables_initializer())
+      train_op = control_flow_ops.group()
+
+      sess.run(train_op)
+      # TODO(karmel): remove explicit call when in the public method.
+      builder._add_train_op(train_op)
+      builder.add_meta_graph_and_variables(sess, ["foo"])
+
+    # Save the SavedModel to disk.
+    builder.save()
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      loader.load(sess, ["foo"], export_dir)
+      self.assertEqual(1, ops.get_collection("v")[0].eval())
+      self.assertEqual(2, ops.get_collection("v")[1].eval())
+      self.assertIsInstance(
+          ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Operation)
+
+  def testTrainOpAfterVariables(self):
+    export_dir = self._get_export_dir("test_train_op_after_variables")
+    builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      # Add `v1` and `v2` variables to the graph.
+      v1 = variables.Variable(1, name="v1")
+      ops.add_to_collection("v", v1)
+      v2 = variables.Variable(2, name="v2")
+      ops.add_to_collection("v", v2)
+
+      sess.run(variables.global_variables_initializer())
+      builder.add_meta_graph_and_variables(sess, ["pre_foo"])
+
+      train_op = state_ops.assign_add(v1, v2)
+      sess.run(train_op)
+      # TODO(karmel): remove explicit call when in the public method.
+      builder._add_train_op(train_op)
+      builder.add_meta_graph(["foo"])
+
+    # Save the SavedModel to disk.
+    builder.save()
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      loader.load(sess, ["foo"], export_dir)
+      self.assertIsInstance(
+          ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      loader.load(sess, ["pre_foo"], export_dir)
+      self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY))
+
   def testMultipleAssets(self):
     export_dir = self._get_export_dir("test_multiple_assets")
     builder = saved_model_builder.SavedModelBuilder(export_dir)
diff --git a/tensorflow/python/saved_model/signature_constants.py b/tensorflow/python/saved_model/signature_constants.py
index 819f351291..99007a9634 100644
--- a/tensorflow/python/saved_model/signature_constants.py
+++ b/tensorflow/python/saved_model/signature_constants.py
@@ -94,3 +94,9 @@ tf_export("saved_model.signature_constants.REGRESS_OUTPUTS").export_constant(
     __name__, "REGRESS_OUTPUTS")
 
 ################################################################################
+# Train/Eval API constants.
+# Not exported while export_all_saved_models is in contrib.
+
+SUPERVISED_TRAIN_METHOD_NAME = "tensorflow/supervised/training"
+
+SUPERVISED_EVAL_METHOD_NAME = "tensorflow/supervised/eval"
diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py
index ea0f52f17e..27d6b70e9d 100644
--- a/tensorflow/python/saved_model/signature_def_utils.py
+++ b/tensorflow/python/saved_model/signature_def_utils.py
@@ -26,6 +26,8 @@ from tensorflow.python.saved_model.signature_def_utils_impl import classificatio
 from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature
 from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
 from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def
+from tensorflow.python.saved_model.signature_def_utils_impl import supervised_eval_signature_def
+from tensorflow.python.saved_model.signature_def_utils_impl import supervised_train_signature_def
 # pylint: enable=unused-import
 
 del absolute_import
diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py
index d033159188..f8ad788f77 100644
--- a/tensorflow/python/saved_model/signature_def_utils_impl.py
+++ b/tensorflow/python/saved_model/signature_def_utils_impl.py
@@ -185,6 +185,62 @@ def predict_signature_def(inputs, outputs):
   return signature_def
 
+
+def supervised_train_signature_def(
+    inputs, loss, predictions=None, metrics=None):
+  return _supervised_signature_def(
+      signature_constants.SUPERVISED_TRAIN_METHOD_NAME, inputs, loss=loss,
+      predictions=predictions, metrics=metrics)
+
+
+def supervised_eval_signature_def(
+    inputs, loss, predictions=None, metrics=None):
+  return _supervised_signature_def(
+      signature_constants.SUPERVISED_EVAL_METHOD_NAME, inputs, loss=loss,
+      predictions=predictions, metrics=metrics)
+
+
+def _supervised_signature_def(
+    method_name, inputs, loss=None, predictions=None,
+    metrics=None):
+  """Creates a signature for training and eval data.
+
+  This function produces signatures that describe the inputs and outputs
+  of a supervised process, such as training or evaluation, that
+  results in loss, metrics, and the like. Note that `inputs` is the only
+  required argument; it must be a non-empty dict.
+
+  Args:
+    method_name: Method name of the SignatureDef as a string.
+    inputs: dict of string to `Tensor`.
+    loss: dict of string to `Tensor` representing computed loss.
+    predictions: dict of string to `Tensor` representing the output
+      predictions.
+    metrics: dict of string to `Tensor` representing metric ops.
+
+  Returns:
+    A train- or eval-flavored signature_def.
+
+  Raises:
+    ValueError: If `inputs` is `None` or empty.
+  """
+  if inputs is None or not inputs:
+    raise ValueError('{} inputs cannot be None or empty.'.format(method_name))
+
+  signature_inputs = {key: utils.build_tensor_info(tensor)
+                      for key, tensor in inputs.items()}
+
+  signature_outputs = {}
+  for output_set in (loss, predictions, metrics):
+    if output_set is not None:
+      sig_out = {key: utils.build_tensor_info(tensor)
+                 for key, tensor in output_set.items()}
+      signature_outputs.update(sig_out)
+
+  signature_def = build_signature_def(
+      signature_inputs, signature_outputs, method_name)
+
+  return signature_def
+
+
 @tf_export('saved_model.signature_def_utils.is_valid_signature')
 def is_valid_signature(signature_def):
   """Determine whether a SignatureDef can be served by TensorFlow Serving."""
diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py
index b2bd14db8c..ebc5450633 100644
--- a/tensorflow/python/saved_model/signature_def_utils_test.py
+++ b/tensorflow/python/saved_model/signature_def_utils_test.py
@@ -180,6 +180,101 @@ class SignatureDefUtilsTest(test.TestCase):
     self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype)
     self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim))
 
+  def testTrainSignatureDef(self):
+    self._testSupervisedSignatureDef(
+        signature_def_utils_impl.supervised_train_signature_def,
+        signature_constants.SUPERVISED_TRAIN_METHOD_NAME)
+
+  def testEvalSignatureDef(self):
+    self._testSupervisedSignatureDef(
+        signature_def_utils_impl.supervised_eval_signature_def,
+        signature_constants.SUPERVISED_EVAL_METHOD_NAME)
+
+  def _testSupervisedSignatureDef(self, fn_to_test, method_name):
+    inputs = {
+        "input-1": constant_op.constant("a", name="input-1"),
+        "input-2": constant_op.constant("b", name="input-2"),
+    }
+    loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
+    predictions = {
+        "classes": constant_op.constant([100], name="classes"),
+    }
+    metrics_val = constant_op.constant(100.0, name="metrics_val")
+    metrics = {
+        "metrics/value": metrics_val,
+        "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"),
+    }
+
+    signature_def = fn_to_test(inputs, loss, predictions, metrics)
+
+    self.assertEqual(method_name, signature_def.method_name)
+
+    # Check inputs in signature def.
+    self.assertEqual(2, len(signature_def.inputs))
+    input1_tensor_info_actual = (signature_def.inputs["input-1"])
+    self.assertEqual("input-1:0", input1_tensor_info_actual.name)
+    self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype)
+    self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim))
+    input2_tensor_info_actual = (signature_def.inputs["input-2"])
+    self.assertEqual("input-2:0", input2_tensor_info_actual.name)
+    self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype)
+    self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim))
+
+    # Check outputs in signature def.
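+    # One loss, one prediction, and two metrics tensors are merged into a
+    # single output map, so four TensorInfo entries are expected in total.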
+    self.assertEqual(4, len(signature_def.outputs))
+    self.assertEqual("loss-1:0", signature_def.outputs["loss-1"].name)
+    self.assertEqual(types_pb2.DT_FLOAT, signature_def.outputs["loss-1"].dtype)
+
+    self.assertEqual("classes:0", signature_def.outputs["classes"].name)
+    self.assertEqual(1, len(signature_def.outputs["classes"].tensor_shape.dim))
+
+    self.assertEqual(
+        "metrics_val:0", signature_def.outputs["metrics/value"].name)
+    self.assertEqual(
+        types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype)
+
+    self.assertEqual(
+        "metrics_op:0", signature_def.outputs["metrics/update_op"].name)
+    self.assertEqual(
+        types_pb2.DT_FLOAT, signature_def.outputs["metrics/update_op"].dtype)
+
+  def testTrainSignatureDefMissingInputs(self):
+    self._testSupervisedSignatureDefMissingInputs(
+        signature_def_utils_impl.supervised_train_signature_def,
+        signature_constants.SUPERVISED_TRAIN_METHOD_NAME)
+
+  def testEvalSignatureDefMissingInputs(self):
+    self._testSupervisedSignatureDefMissingInputs(
+        signature_def_utils_impl.supervised_eval_signature_def,
+        signature_constants.SUPERVISED_EVAL_METHOD_NAME)
+
+  def _testSupervisedSignatureDefMissingInputs(self, fn_to_test, method_name):
+    inputs = {
+        "input-1": constant_op.constant("a", name="input-1"),
+        "input-2": constant_op.constant("b", name="input-2"),
+    }
+    loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
+    predictions = {
+        "classes": constant_op.constant([100], name="classes"),
+    }
+    metrics_val = constant_op.constant(100, name="metrics_val")
+    metrics = {
+        "metrics/value": metrics_val,
+        "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"),
+    }
+
+    with self.assertRaises(ValueError):
+      signature_def = fn_to_test(
+          {}, loss=loss, predictions=predictions, metrics=metrics)
+
+    signature_def = fn_to_test(inputs, loss=loss)
+    self.assertEqual(method_name, signature_def.method_name)
+    self.assertEqual(1, len(signature_def.outputs))
+
+    signature_def = fn_to_test(inputs, metrics=metrics, loss=loss)
+    self.assertEqual(method_name, signature_def.method_name)
+    self.assertEqual(3, len(signature_def.outputs))
+
   def testGetShapeAndTypes(self):
     inputs = {
         "input-1": constant_op.constant(["a", "b"]),
diff --git a/tensorflow/python/saved_model/tag_constants.py b/tensorflow/python/saved_model/tag_constants.py
index 5a797da791..c82154e7b9 100644
--- a/tensorflow/python/saved_model/tag_constants.py
+++ b/tensorflow/python/saved_model/tag_constants.py
@@ -32,6 +32,9 @@ TRAINING = "train"
 tf_export("saved_model.tag_constants.TRAINING").export_constant(
     __name__, "TRAINING")
 
+# Tag for the `eval` graph. Not exported while the export logic is in contrib.
+EVAL = "eval"
+
 # Tag for the `gpu` graph.
 GPU = "gpu"
 tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU")
@@ -39,3 +42,5 @@ tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU")
 # Tag for the `tpu` graph.
 TPU = "tpu"
 tf_export("saved_model.tag_constants.TPU").export_constant(__name__, "TPU")
+
+
diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py
index 05afd37ccd..d00312a1f3 100644
--- a/tensorflow/python/training/checkpointable.py
+++ b/tensorflow/python/training/checkpointable.py
@@ -659,6 +659,31 @@ class CheckpointableBase(object):
     return {}
 
+
+class NoDependency(object):
+  """Allows attribute assignment to `Checkpointable` objects with no dependency.
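+
+  The wrapped value is assigned to the attribute as-is; the only effect of
+  `NoDependency` is that no checkpoint dependency is created for it.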
+ + Example usage: + ```python + obj = Checkpointable() + obj.has_dependency = tf.Variable(0., name="dep") + obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) + assert obj.no_dependency.name == "nodep:0" + ``` + + `obj` in this example has a dependency on the variable "dep", and both + attributes contain un-wrapped `Variable` objects. + + `NoDependency` also works with `tf.keras.Model`, but only for checkpoint + dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) + `Layer` to the attribute without a checkpoint dependency, but the `Model` will + still track the `Layer` (so it will appear in `Model.layers`, and its + variables will appear in `Model.variables`). + """ + + def __init__(self, value): + self.value = value + + class Checkpointable(CheckpointableBase): """Manages dependencies on other objects. @@ -691,8 +716,11 @@ class Checkpointable(CheckpointableBase): """Support self.foo = checkpointable syntax.""" # Perform the attribute assignment, and potentially call other __setattr__ # overrides such as that for tf.keras.Model. + no_dependency = isinstance(value, NoDependency) + if no_dependency: + value = value.value super(Checkpointable, self).__setattr__(name, value) - if isinstance(value, CheckpointableBase): + if not no_dependency and isinstance(value, CheckpointableBase): self._track_checkpointable( value, name=name, # Allow the user to switch the Checkpointable which is tracked by this diff --git a/tensorflow/python/training/checkpointable_test.py b/tensorflow/python/training/checkpointable_test.py index e79acb4975..85802cb661 100644 --- a/tensorflow/python/training/checkpointable_test.py +++ b/tensorflow/python/training/checkpointable_test.py @@ -34,6 +34,16 @@ class InterfaceTests(test.TestCase): root.leaf = duplicate_name_dep root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True) + def testNoDependency(self): + root = checkpointable.Checkpointable() + hasdep = checkpointable.Checkpointable() + root.hasdep = hasdep + nodep = checkpointable.Checkpointable() + root.nodep = checkpointable.NoDependency(nodep) + self.assertEqual(1, len(root._checkpoint_dependencies)) + self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep) + self.assertIs(root.hasdep, hasdep) + self.assertIs(root.nodep, nodep) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py index cf4112ff99..f2a2b411fd 100644 --- a/tensorflow/python/training/checkpointable_utils.py +++ b/tensorflow/python/training/checkpointable_utils.py @@ -1044,8 +1044,11 @@ class Checkpoint(checkpointable_lib.Checkpointable): if self._save_counter is None: # Initialized to 0 and incremented before saving. with ops.device("/cpu:0"): - self._save_counter = add_variable( - self, name="save_counter", initializer=0, dtype=dtypes.int64) + # add_variable creates a dependency named "save_counter"; NoDependency + # prevents creating a second dependency named "_save_counter". + self._save_counter = checkpointable_lib.NoDependency( + add_variable(self, name="save_counter", initializer=0, + dtype=dtypes.int64)) @property def save_counter(self): diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index c0d5ea36dd..ab8b37bb65 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -357,14 +357,14 @@ class DistributionStrategy(object): on different slices of the input data. 
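  For example, with two towers, each tower processes one half of every
  input batch.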
This is in contrast to _model parallelism_ where we divide up a single copy of a model across multiple devices. - Note: for now we only support data parallelism at this time, but + Note: we only support data parallelism for now, but hope to add support for model parallelism in the future. * A _tower_ is one copy of the model, running on one slice of the input data. - * _Synchronous_, or more commonly _sync_, training is when the + * _Synchronous_, or more commonly _sync_, training is where the updates from each tower are aggregated together before updating the model variables. This is in contrast to _asynchronous_, or - _async_ training where each tower updates the model variables + _async_ training, where each tower updates the model variables independently. * Furthermore you might run your computation on multiple devices on one machine (or "host"), or on multiple machines/hosts. @@ -386,11 +386,11 @@ class DistributionStrategy(object): * Reductions and Allreduce: A _reduction_ is some method of aggregating multiple values into one value, like "sum" or "mean". If doing sync training, we will perform a reduction on the - gradients to a parameter from each tower before applying the + gradients to a parameter from all towers before applying the update. Allreduce is an algorithm for performing a reduction on values from multiple devices and making the result available on all of those devices. - * In the future we will have support for TensorFlows' partitioned + * In the future we will have support for TensorFlow's partitioned variables, where a single variable is split across multiple devices. @@ -419,9 +419,9 @@ class DistributionStrategy(object): `tower_fn` can use the `get_tower_context()` API to get enhanced behavior in this case. - You can also create an initializable iterator instead of one shot iterator. - In that case, you will need to ensure that you initialize the iterator - before calling get_next. + You can also create an initializable iterator instead of a one-shot + iterator. In that case, you will need to ensure that you initialize the + iterator before calling get_next. ``` iterator = my_distribution.distribute_dataset( dataset).make_initializable_iterator()) @@ -816,6 +816,7 @@ class DistributionStrategy(object): # TODO(josh11b): Return an unwrapped value if colocate_with is a # single device. 
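+    # Only "sum" and "mean" reductions are supported; the assertion below
+    # enforces this before dispatching to _reduce.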
_require_cross_tower_context(self) + assert method_string in ("sum", "mean") return self._reduce(method_string, value, destinations) def _reduce(self, method_string, value, destinations): diff --git a/tensorflow/stream_executor/cuda/cuda_activation.cc b/tensorflow/stream_executor/cuda/cuda_activation.cc index cf6b9e2c6e..02371c3c3a 100644 --- a/tensorflow/stream_executor/cuda/cuda_activation.cc +++ b/tensorflow/stream_executor/cuda/cuda_activation.cc @@ -38,5 +38,11 @@ ScopedActivateExecutorContext::~ScopedActivateExecutorContext() { delete static_cast<ScopedActivateContext *>(driver_scoped_activate_context_); } +ScopedActivateExecutorContext::ScopedActivateExecutorContext( + ScopedActivateExecutorContext &&other) + : driver_scoped_activate_context_(other.driver_scoped_activate_context_) { + other.driver_scoped_activate_context_ = nullptr; +} + } // namespace cuda } // namespace stream_executor diff --git a/tensorflow/stream_executor/cuda/cuda_activation.h b/tensorflow/stream_executor/cuda/cuda_activation.h index 04ffaef364..ef9807820f 100644 --- a/tensorflow/stream_executor/cuda/cuda_activation.h +++ b/tensorflow/stream_executor/cuda/cuda_activation.h @@ -44,10 +44,11 @@ class ScopedActivateExecutorContext { // fatal failure if it is not CUDA inside. explicit ScopedActivateExecutorContext(StreamExecutor* stream_exec); + ScopedActivateExecutorContext(ScopedActivateExecutorContext&& other); + ~ScopedActivateExecutorContext(); private: - // The cuda.h-using datatype that we wrap. ScopedActivateContext* driver_scoped_activate_context_; diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 773cac2c40..af78efe81d 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -46,8 +46,20 @@ limitations under the License. #include "cuda/include/cudnn.h" // clang-format on +namespace stream_executor { +namespace cuda { + +PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin); + namespace { +// TODO(csigg): remove dnn namespace qualifier from the RNN code below. +using ::stream_executor::dnn::BatchDescriptor; +using ::stream_executor::dnn::ConvolutionDescriptor; +using ::stream_executor::dnn::FilterDescriptor; +using ::stream_executor::dnn::NormalizeDescriptor; +using ::stream_executor::dnn::PoolingDescriptor; + // Converts (via narrowing) a type T value to a type U, and checks that the // value has no value change due to the conversion. template <typename WideT, typename NarrowT> @@ -58,20 +70,6 @@ NarrowT CheckedNarrowing(const WideT& wide) { return narrow; } -} // namespace - -namespace stream_executor { - -using dnn::BatchDescriptor; -using dnn::FilterDescriptor; -using dnn::ConvolutionDescriptor; -using dnn::PoolingDescriptor; -using dnn::NormalizeDescriptor; - -namespace cuda { - -PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin); - string ToString(cudnnStatus_t status) { switch (status) { case CUDNN_STATUS_SUCCESS: @@ -136,226 +134,82 @@ cudnnDataType_t GetCudnnDataType<Eigen::half>() { return CUDNN_DATA_HALF; } -namespace wrap { - -static port::ThreadPool* InitCudnnThreadpool() { - port::ThreadPool* cudnn_threadpool_; - port::ThreadOptions options; - // TBD(keveman): Conservatively setting the stack size and guard size to 2MB, - // until we can get some guarantees from NVIDIA on the minimum stack space - // they will work with. 
- options.stack_size = 2 * 1024 * 1024; - options.guard_size = 2 * 1024 * 1024; - cudnn_threadpool_ = new port::ThreadPool(port::Env::Default(), options, - "cudnn_threadpool", 1); - CHECK(cudnn_threadpool_); - return cudnn_threadpool_; -} - -static mutex cudnn_threadpool_mu(LINKER_INITIALIZED); -static port::ThreadPool* GetCudaThreadpool() { - mutex_lock lock(cudnn_threadpool_mu); - static port::ThreadPool* cudnn_threadpool = InitCudnnThreadpool(); - return cudnn_threadpool; -} - -#define STREAM_EXECUTOR_CUDNN_WRAP(__name) \ - struct WrapperShim__##__name { \ - template <typename... Args> \ - cudnnStatus_t operator()(CUDAExecutor* parent, Args... args) { \ - cuda::ScopedActivateExecutorContext sac{parent}; \ - cudnnStatus_t retval = ::__name(args...); \ - return retval; \ - } \ - } __name; - -#define STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM(__name) \ - struct WrapperShim__##__name { \ - template <typename... Args> \ - cudnnStatus_t operator()(CudnnSupport* dnn, Stream* s, Args... args) \ - SHARED_LOCKS_REQUIRED(dnn->dnn_handle_mutex_) { \ - CHECK_NOTNULL(s); \ - CHECK_EQ(s, dnn->GetCurrentDnnStream()) \ - << "Stream is not set correctly!"; \ - cuda::ScopedActivateExecutorContext sac{dnn->GetParentExecutor()}; \ - cudnnStatus_t retval = ::__name(args...); \ - return retval; \ - } \ - } __name; - -// Handles cudnnSetStream differently in order to add debug information. -struct WrapperShim__cudnnSetStream { - cudnnStatus_t operator()(CudnnSupport* dnn, Stream* stream, - cudnnHandle_t handle) - EXCLUSIVE_LOCKS_REQUIRED(dnn->dnn_handle_mutex_) { - dnn->SetCurrentDnnStream(stream); - cuda::ScopedActivateExecutorContext sac{dnn->GetParentExecutor()}; - cudnnStatus_t retval = ::cudnnSetStream(handle, AsCUDAStreamValue(stream)); - return retval; - } -} cudnnSetStream; - -// clang-format off -#define CUDNN_DNN_ROUTINE_EACH(__macro) \ - __macro(cudnnGetConvolutionNdForwardOutputDim) \ - __macro(cudnnGetConvolutionForwardAlgorithm) \ - __macro(cudnnCreateTensorDescriptor) \ - __macro(cudnnDestroyTensorDescriptor) \ - __macro(cudnnCreateFilterDescriptor) \ - __macro(cudnnSetPoolingNdDescriptor) \ - __macro(cudnnSetLRNDescriptor) \ - __macro(cudnnDestroyFilterDescriptor) \ - __macro(cudnnCreateConvolutionDescriptor) \ - __macro(cudnnCreatePoolingDescriptor) \ - __macro(cudnnDestroyPoolingDescriptor) \ - __macro(cudnnCreateLRNDescriptor) \ - __macro(cudnnDestroyLRNDescriptor) \ - __macro(cudnnDestroyConvolutionDescriptor) \ - __macro(cudnnCreate) \ - __macro(cudnnDestroy) \ - __macro(cudnnGetConvolutionForwardWorkspaceSize) \ - __macro(cudnnSetConvolutionNdDescriptor) \ - __macro(cudnnSetTensor4dDescriptor) \ - __macro(cudnnSetTensorNdDescriptor) \ - __macro(cudnnSetFilterNdDescriptor) - -// clang-format on -CUDNN_DNN_ROUTINE_EACH(STREAM_EXECUTOR_CUDNN_WRAP) -#undef CUDNN_DNN_ROUTINE_EACH - -// clang-format off -#define CUDNN_DNN_ROUTINE_EACH_WITH_STREAM(__macro) \ - __macro(cudnnBatchNormalizationBackward) \ - __macro(cudnnBatchNormalizationForwardInference) \ - __macro(cudnnBatchNormalizationForwardTraining) \ - __macro(cudnnActivationForward) \ - __macro(cudnnConvolutionForward) \ - __macro(cudnnConvolutionBackwardBias) \ - __macro(cudnnTransformTensor) \ - __macro(cudnnPoolingForward) \ - __macro(cudnnPoolingBackward) \ - __macro(cudnnLRNCrossChannelForward) \ - __macro(cudnnLRNCrossChannelBackward) \ - __macro(cudnnAddTensor) \ - __macro(cudnnConvolutionBackwardData) \ - __macro(cudnnConvolutionBackwardFilter) - -// clang-format on -CUDNN_DNN_ROUTINE_EACH_WITH_STREAM( - 
STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM) -#undef CUDNN_DNN_ROUTINE_EACH_WITH_STREAM - -// APIs available after R3: -#if CUDNN_VERSION >= 3000 -#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \ - __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \ - __macro(cudnnGetConvolutionBackwardDataAlgorithm) \ - __macro(cudnnGetConvolutionBackwardFilterAlgorithm) \ - __macro(cudnnGetConvolutionBackwardDataWorkspaceSize) -CUDNN_DNN_ROUTINE_EACH_AFTER_R3(STREAM_EXECUTOR_CUDNN_WRAP) -#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3 -#endif - -// APIs in R3 but not in R5 -// clang-format off -#if CUDNN_VERSION >= 3000 && CUDNN_VERSION < 5000 -#define CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM(__macro) \ - __macro(cudnnAddTensor_v3) \ - __macro(cudnnConvolutionBackwardData_v3) \ - __macro(cudnnConvolutionBackwardFilter_v3) -// clang-format on - -CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM( - STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM) -#undef CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM -#endif - -// APIs in R5 -// clang-format off -#if CUDNN_VERSION >= 5000 -#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \ - __macro(cudnnCreateActivationDescriptor) \ - __macro(cudnnSetActivationDescriptor) \ - __macro(cudnnGetActivationDescriptor) \ - __macro(cudnnDestroyActivationDescriptor) \ - __macro(cudnnCreateDropoutDescriptor) \ - __macro(cudnnDestroyDropoutDescriptor) \ - __macro(cudnnSetDropoutDescriptor) \ - __macro(cudnnDropoutGetStatesSize) \ - __macro(cudnnCreateRNNDescriptor) \ - __macro(cudnnDestroyRNNDescriptor) \ - __macro(cudnnGetRNNParamsSize) \ - __macro(cudnnGetRNNWorkspaceSize) \ - __macro(cudnnGetRNNTrainingReserveSize) \ - __macro(cudnnGetRNNLinLayerMatrixParams) \ - __macro(cudnnGetRNNLinLayerBiasParams) \ - __macro(cudnnSetRNNDescriptor) \ - __macro(cudnnGetFilterNdDescriptor) - -// clang-format on -CUDNN_DNN_ROUTINE_EACH_R5(STREAM_EXECUTOR_CUDNN_WRAP) -#undef CUDNN_DNN_ROUTINE_EACH_R5 - -// clang-format off -#define CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM(__macro) \ - __macro(cudnnRNNForwardInference) \ - __macro(cudnnRNNForwardTraining) \ - __macro(cudnnRNNBackwardData) \ - __macro(cudnnRNNBackwardWeights) +// RAII wrapper for all calls to cuDNN with a cuDNN handle argument. +// +// See CudnnAccess::GetHandle() for details. +class CudnnHandle { + public: + // Takes ownership of the executor context and the lock to access cuDNN + // using handle. + CudnnHandle(cuda::ScopedActivateExecutorContext context, mutex_lock lock, + cudnnHandle_t handle) + : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {} -// clang-format on -CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM( - STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM) -#undef CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM -#endif + // Returns cuDNN handle. To be passed directly to cuDNN APIs, don't keep + // a copy. + cudnnHandle_t handle() const { return handle_; } -// APIs in R6 -// clang-format off -#if CUDNN_VERSION >= 6000 -#define CUDNN_DNN_ROUTINE_EACH_R6(__macro) \ - __macro(cudnnSetRNNDescriptor_v6) \ - __macro(cudnnCreatePersistentRNNPlan) \ - __macro(cudnnDestroyPersistentRNNPlan) \ - __macro(cudnnSetPersistentRNNPlan) + private: + cuda::ScopedActivateExecutorContext context_; + mutex_lock lock_; + cudnnHandle_t handle_; // Not owned. 
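+
+  // Note: members are destroyed in reverse declaration order, so the lock is
+  // released before the CUDA context is deactivated.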
+};
 
-// clang-format on
-CUDNN_DNN_ROUTINE_EACH_R6(STREAM_EXECUTOR_CUDNN_WRAP)
-#undef CUDNN_DNN_ROUTINE_EACH_R6
+}  // namespace
 
-// clang-format off
-#define CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(__macro) \
-  __macro(cudnnConvolutionBiasActivationForward)
+// Wraps a cuDNN handle and provides access to it through CudnnHandle
+// instances, each of which locks a mutex, acquires the CUDA context, and
+// sets the stream that cuDNN should use to enqueue any work.
+//
+// Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
+class CudnnAccess {
+ public:
+  // Takes ownership of the handle.
+  explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
 
-// clang-format on
-CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(
-    STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
-#undef CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM
-#endif
+  ~CudnnAccess() {
+    mutex_lock lock(mutex_);
+    cudnnDestroy(handle_);
+  }
 
-// APIs in R7
-// clang-format off
-#if CUDNN_VERSION >= 7000
-#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
-  __macro(cudnnSetConvolutionMathType) \
-  __macro(cudnnSetRNNMatrixMathType) \
-  __macro(cudnnSetConvolutionGroupCount) \
-  __macro(cudnnGetConvolutionGroupCount)
+  // Creates a CudnnHandle instance for the given stream.
+  //
+  // cuDNN API calls using the same handle instance need to be serialized
+  // across threads. This is guaranteed by CudnnHandle instances locking the
+  // mutex owned by this class.
+  //
+  // Most cuDNN APIs taking a handle perform work on a CUDA stream. The
+  // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
+  // to use the provided stream.
+  //
+  // The stream argument may be null, which translates to the legacy default
+  // stream. See
+  // https://docs.nvidia.com/cuda/cuda-driver-api/stream-sync-behavior.html.
+  // The legacy default stream synchronizes with all other streams and it is
+  // therefore a bad idea (performance-wise) to call any cuDNN APIs that
+  // enqueue work in the stream.
+  CudnnHandle GetHandle(CUDAExecutor* executor, Stream* stream) {
+    mutex_lock lock(mutex_);
+    cuda::ScopedActivateExecutorContext context(executor);
+    CUstream cu_stream = stream ? AsCUDAStreamValue(stream) : cudaStreamLegacy;
+    auto status = cudnnSetStream(handle_, cu_stream);
+    CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
+    return CudnnHandle(std::move(context), std::move(lock), handle_);
+  }
 
-// clang-format on
-CUDNN_DNN_ROUTINE_EACH_R7(STREAM_EXECUTOR_CUDNN_WRAP)
-#undef CUDNN_DNN_ROUTINE_EACH_R7
-#endif
+ private:
+  // Guards the enqueueing of cuDNN operations via the handle_ below.
+  mutex mutex_;
 
-}  // namespace wrap
+  // cuDNN library handle.
+  cudnnHandle_t handle_ GUARDED_BY(mutex_);  // Owned.
+}; namespace { cudnnDataType_t GetRnnComputeType(dnn::DataType data_type); -cudnnHandle_t ToHandle(void* opaque_handle) { - return static_cast<cudnnHandle_t>(opaque_handle); -} - cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) { cudnnConvolutionFwdAlgo_t algo = cudnnConvolutionFwdAlgo_t(algorithm.algo_id()); @@ -432,7 +286,7 @@ port::Status GetCudnnProperty(libraryPropertyType type, int* value) { port::StrCat("cudnnGetProperty failed for type: ", ToString(type), " with status: ", ToString(status)); LOG(ERROR) << error; - return port::Status{port::error::INTERNAL, error}; + return port::Status(port::error::INTERNAL, error); } return port::Status::OK(); } @@ -471,19 +325,11 @@ port::Status GetLoadedCudnnVersion(CudnnVersion* version) { } // namespace -CudnnSupport::CudnnSupport(CUDAExecutor* parent) - : parent_(parent), dnn_handle_(nullptr), current_dnn_stream_(nullptr) {} - -CudnnSupport::~CudnnSupport() { - auto status = wrap::cudnnDestroy(parent_, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn handle: " << ToString(status); - } -} +CudnnSupport::CudnnSupport(CUDAExecutor* parent) : parent_(parent) {} port::Status CudnnSupport::Init() { - auto status = wrap::cudnnCreate( - parent_, reinterpret_cast<cudnnHandle_t*>(&dnn_handle_)); + cudnnHandle_t cudnn_handle = nullptr; + auto status = cudnnCreate(&cudnn_handle); if (status == CUDNN_STATUS_SUCCESS) { CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL); @@ -499,9 +345,10 @@ port::Status CudnnSupport::Init() { "from sources, make sure the library loaded at runtime is compatible " "with the version specified during compile configuration."); LOG(ERROR) << error; - return port::Status{port::error::INTERNAL, error}; + return port::Status(port::error::INTERNAL, error); } + cudnn_.reset(new CudnnAccess(cudnn_handle)); return port::Status::OK(); } @@ -525,9 +372,9 @@ port::Status CudnnSupport::Init() { } } - return port::Status{port::error::INTERNAL, + return port::Status(port::error::INTERNAL, port::StrCat("cudnn library could not create a handle: ", - ToString(status))}; + ToString(status))); } port::StatusOr<perftools::gputools::dnn::VersionInfo> @@ -538,14 +385,15 @@ CudnnSupport::GetVersion() { version.major_version, version.minor_version, version.patch_level); } +namespace { + // Turns a BatchDescriptor structure into a cudnn tensor handle within a scope. 
class ScopedTensorDescriptor { public: - ScopedTensorDescriptor(CUDAExecutor* parent, - const BatchDescriptor& batch_descriptor, + ScopedTensorDescriptor(const BatchDescriptor& batch_descriptor, cudnnDataType_t elem_type) - : parent_(parent), handle_(nullptr) { - cudnnStatus_t status = wrap::cudnnCreateTensorDescriptor(parent_, &handle_); + : handle_(nullptr) { + cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not create cudnn tensor descriptor: " << ToString(status); @@ -568,8 +416,8 @@ class ScopedTensorDescriptor { &CheckedNarrowing<int64, int>); std::transform(dims64.cbegin(), dims64.cend(), dims.begin(), &CheckedNarrowing<int64, int>); - status = wrap::cudnnSetTensorNdDescriptor( - parent_, handle_, elem_type, nd, dims.data(), strides.data()); + status = cudnnSetTensorNdDescriptor(handle_, elem_type, nd, dims.data(), + strides.data()); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not convert BatchDescriptor " @@ -579,8 +427,8 @@ class ScopedTensorDescriptor { } break; #if CUDNN_VERSION >= 6000 case dnn::DataLayout::kBatchDepthYX4: { - status = wrap::cudnnSetTensor4dDescriptor( - parent_, handle_, CUDNN_TENSOR_NCHW_VECT_C, elem_type, + status = cudnnSetTensor4dDescriptor( + handle_, CUDNN_TENSOR_NCHW_VECT_C, elem_type, batch_descriptor.count(), batch_descriptor.feature_map_count(), batch_descriptor.height(), batch_descriptor.width()); if (status != CUDNN_STATUS_SUCCESS) { @@ -598,7 +446,7 @@ class ScopedTensorDescriptor { } ~ScopedTensorDescriptor() { - cudnnStatus_t status = wrap::cudnnDestroyTensorDescriptor(parent_, handle_); + cudnnStatus_t status = cudnnDestroyTensorDescriptor(handle_); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "could not destroy cudnn tensor descriptor: " << ToString(status); @@ -608,7 +456,6 @@ class ScopedTensorDescriptor { cudnnTensorDescriptor_t handle() const { return handle_; } private: - CUDAExecutor* parent_; // Parent executor. Not owned. cudnnTensorDescriptor_t handle_; // Owned. SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor); @@ -617,12 +464,10 @@ class ScopedTensorDescriptor { // Turns a FilterDescriptor structure into a cudnn filter handle within a scope. 
class ScopedFilterDescriptor { public: - ScopedFilterDescriptor(CUDAExecutor* parent, - const FilterDescriptor& filter_descriptor, - const BatchDescriptor& batch_descriptor, + ScopedFilterDescriptor(const FilterDescriptor& filter_descriptor, cudnnDataType_t elem_type) - : parent_(parent), handle_(nullptr) { - cudnnStatus_t status = wrap::cudnnCreateFilterDescriptor(parent_, &handle_); + : handle_(nullptr) { + cudnnStatus_t status = cudnnCreateFilterDescriptor(&handle_); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not create cudnn filter descriptor: " << ToString(status); @@ -656,11 +501,11 @@ class ScopedFilterDescriptor { const auto& spatial_dims = filter_descriptor.input_filter_dims(); std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2); - status = wrap::cudnnSetFilterNdDescriptor(parent_, handle_, elem_type, + status = cudnnSetFilterNdDescriptor(handle_, elem_type, #if CUDNN_VERSION >= 5000 - format, + format, #endif - dims.size(), dims.data()); + dims.size(), dims.data()); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not set cudnn filter descriptor: " << ToString(status); @@ -668,7 +513,7 @@ class ScopedFilterDescriptor { } ~ScopedFilterDescriptor() { - cudnnStatus_t status = wrap::cudnnDestroyFilterDescriptor(parent_, handle_); + cudnnStatus_t status = cudnnDestroyFilterDescriptor(handle_); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "could not destroy cudnn filter descriptor: " << ToString(status); @@ -678,11 +523,7 @@ class ScopedFilterDescriptor { cudnnFilterDescriptor_t handle() const { return handle_; } private: - // Parent executor object. Not owned. - CUDAExecutor* parent_; - - // cudnn filter descriptor this object creates. Owned. - cudnnFilterDescriptor_t handle_; + cudnnFilterDescriptor_t handle_; // Owned. SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor); }; @@ -736,11 +577,10 @@ static bool BatchnormSpatialPersistentEnabled() { class ScopedConvolutionDescriptor { public: ScopedConvolutionDescriptor( - CUDAExecutor* parent, const ConvolutionDescriptor& convolution_descriptor, + const ConvolutionDescriptor& convolution_descriptor, cudnnDataType_t data_type) - : parent_(parent), handle_(nullptr) { - cudnnStatus_t status = - wrap::cudnnCreateConvolutionDescriptor(parent_, &handle_); + : handle_(nullptr) { + cudnnStatus_t status = cudnnCreateConvolutionDescriptor(&handle_); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not create cudnn convolution descriptor: " << ToString(status); @@ -766,9 +606,9 @@ class ScopedConvolutionDescriptor { std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(), &CheckedNarrowing<int64, int>); - status = wrap::cudnnSetConvolutionNdDescriptor( - parent_, handle_, convolution_descriptor.ndims(), padding.data(), - strides.data(), dilations.data(), + status = cudnnSetConvolutionNdDescriptor( + handle_, convolution_descriptor.ndims(), padding.data(), strides.data(), + dilations.data(), // NOTE(keveman): cuDNN supports convolution and cross correlation. // However, almost all the use cases do cross correlation, so just // hard coding it here. 
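         // (Cross correlation omits the kernel flip of true convolution, i.e.
         // output[i] = sum_k input[i + k] * filter[k], which is what
         // TensorFlow's convolution ops actually compute.)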
@@ -785,8 +625,8 @@ class ScopedConvolutionDescriptor { #if CUDNN_MAJOR >= 7 VLOG(2) << "Requesting grouped convolution: " << convolution_descriptor.group_count(); - status = wrap::cudnnSetConvolutionGroupCount( - parent_, handle_, convolution_descriptor.group_count()); + status = cudnnSetConvolutionGroupCount( + handle_, convolution_descriptor.group_count()); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not set cudnn convolution group count: " << ToString(status); @@ -802,8 +642,7 @@ class ScopedConvolutionDescriptor { cudnnMathType_t math_type = (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH); if (TensorOpMathEnabled()) { - cudnnStatus_t status = - wrap::cudnnSetConvolutionMathType(parent_, handle_, math_type); + cudnnStatus_t status = cudnnSetConvolutionMathType(handle_, math_type); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not set cudnn convolution math type: " << ToString(status); @@ -813,8 +652,7 @@ class ScopedConvolutionDescriptor { } ~ScopedConvolutionDescriptor() { - cudnnStatus_t status = - wrap::cudnnDestroyConvolutionDescriptor(parent_, handle_); + cudnnStatus_t status = cudnnDestroyConvolutionDescriptor(handle_); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "could not destroy cudnn convolution descriptor: " << ToString(status); @@ -824,7 +662,6 @@ class ScopedConvolutionDescriptor { cudnnConvolutionDescriptor_t handle() const { return handle_; } private: - CUDAExecutor* parent_; // Parent executor. Not owned. cudnnConvolutionDescriptor_t handle_; // Owned. SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor); @@ -834,11 +671,9 @@ class ScopedConvolutionDescriptor { // within a scope. class ScopedPoolingDescriptor { public: - ScopedPoolingDescriptor(CUDAExecutor* parent, - const PoolingDescriptor& pooling_descriptor) - : parent_(parent), handle_(nullptr) { - cudnnStatus_t status = - wrap::cudnnCreatePoolingDescriptor(parent_, &handle_); + explicit ScopedPoolingDescriptor(const PoolingDescriptor& pooling_descriptor) + : handle_(nullptr) { + cudnnStatus_t status = cudnnCreatePoolingDescriptor(&handle_); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not create cudnn pooling descriptor: " << ToString(status); @@ -858,8 +693,8 @@ class ScopedPoolingDescriptor { std::transform(shape64.cbegin(), shape64.cend(), shape.begin(), &CheckedNarrowing<int64, int>); bool propagate_nans = pooling_descriptor.propagate_nans(); - status = wrap::cudnnSetPoolingNdDescriptor( - parent_, handle_, + status = cudnnSetPoolingNdDescriptor( + handle_, (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum ? CUDNN_POOLING_MAX : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING), @@ -873,8 +708,7 @@ class ScopedPoolingDescriptor { } } ~ScopedPoolingDescriptor() { - cudnnStatus_t status = - wrap::cudnnDestroyPoolingDescriptor(parent_, handle_); + cudnnStatus_t status = cudnnDestroyPoolingDescriptor(handle_); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "could not destroy cudnn pooling descriptor: " << ToString(status); @@ -884,7 +718,6 @@ class ScopedPoolingDescriptor { cudnnPoolingDescriptor_t handle() const { return handle_; } private: - CUDAExecutor* parent_; // Parent executor. Not owned. cudnnPoolingDescriptor_t handle_; // Owned. SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor); @@ -893,10 +726,10 @@ class ScopedPoolingDescriptor { // Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle. 
class ScopedNormalizeDescriptor { public: - ScopedNormalizeDescriptor(CUDAExecutor* parent, - const NormalizeDescriptor& normalize_descriptor) - : parent_(parent), handle_(nullptr) { - cudnnStatus_t status = wrap::cudnnCreateLRNDescriptor(parent_, &handle_); + explicit ScopedNormalizeDescriptor( + const NormalizeDescriptor& normalize_descriptor) + : handle_(nullptr) { + cudnnStatus_t status = cudnnCreateLRNDescriptor(&handle_); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not create cudnn LRN descriptor: " << ToString(status); @@ -922,15 +755,14 @@ class ScopedNormalizeDescriptor { double lrnBeta = normalize_descriptor.beta(); double lrnK = normalize_descriptor.bias(); - status = wrap::cudnnSetLRNDescriptor(parent_, handle_, lrnN, lrnAlpha, - lrnBeta, lrnK); + status = cudnnSetLRNDescriptor(handle_, lrnN, lrnAlpha, lrnBeta, lrnK); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not set cudnn LRN descriptor: " << ToString(status); } } ~ScopedNormalizeDescriptor() { - cudnnStatus_t status = wrap::cudnnDestroyLRNDescriptor(parent_, handle_); + cudnnStatus_t status = cudnnDestroyLRNDescriptor(handle_); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "could not destroy cudnn LRN descriptor: " << ToString(status); @@ -940,7 +772,6 @@ class ScopedNormalizeDescriptor { cudnnLRNDescriptor_t handle() const { return handle_; } private: - CUDAExecutor* parent_; // Parent executor. Not owned. cudnnLRNDescriptor_t handle_; // Owned. SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor); @@ -951,13 +782,11 @@ class ScopedNormalizeDescriptor { // descriptor handle within a scope. class ScopedActivationDescriptor { public: - ScopedActivationDescriptor(CUDAExecutor* parent, - dnn::ActivationMode activation_mode, + ScopedActivationDescriptor(dnn::ActivationMode activation_mode, cudnnNanPropagation_t nan_propagation, double value_max) - : parent_(parent), handle_(nullptr) { - cudnnStatus_t status = - wrap::cudnnCreateActivationDescriptor(parent_, &handle_); + : handle_(nullptr) { + cudnnStatus_t status = cudnnCreateActivationDescriptor(&handle_); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not create cudnn activation descriptor: " << ToString(status); @@ -988,8 +817,8 @@ class ScopedActivationDescriptor { << static_cast<int>(activation_mode); } - status = wrap::cudnnSetActivationDescriptor(parent_, handle_, mode, - nan_propagation, relu_ceiling); + status = cudnnSetActivationDescriptor(handle_, mode, nan_propagation, + relu_ceiling); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "could not set cudnn activation descriptor: " << ToString(status); @@ -997,8 +826,7 @@ class ScopedActivationDescriptor { } ~ScopedActivationDescriptor() { - cudnnStatus_t status = - wrap::cudnnDestroyActivationDescriptor(parent_, handle_); + cudnnStatus_t status = cudnnDestroyActivationDescriptor(handle_); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "could not destroy cudnn activation descriptor: " << ToString(status); @@ -1008,14 +836,12 @@ class ScopedActivationDescriptor { cudnnActivationDescriptor_t handle() const { return handle_; } private: - CUDAExecutor* parent_; // Parent executor. Not owned. cudnnActivationDescriptor_t handle_; // Owned. 
SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor); }; #endif -namespace { cudnnDataType_t ToCudnnDataType( dnn::DataType data_type, dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) { @@ -1090,8 +916,6 @@ class MixinBase : public Base {}; template <> class MixinBase<void> {}; -} // namespace - #if CUDNN_VERSION >= 5000 #define CUDNN_RETURN_IF_FAIL(STATUS, ...) \ @@ -1102,6 +926,7 @@ class MixinBase<void> {}; return; \ } +// TODO(csigg): Remove inheritance for code reuse. template <typename Base> class CudnnDescriptorCommon : public MixinBase<Base> { public: @@ -1115,12 +940,11 @@ class CudnnDescriptorCommon : public MixinBase<Base> { class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> { public: - CudnnDropoutDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle, - float dropout, uint64 seed, + CudnnDropoutDescriptor(const CudnnHandle& cudnn, float dropout, uint64 seed, ScratchAllocator* state_allocator) - : parent_(parent), handle_(nullptr) { + : handle_(nullptr) { cudnnStatus_t status; - status = wrap::cudnnCreateDropoutDescriptor(parent_, &handle_); + status = cudnnCreateDropoutDescriptor(&handle_); CUDNN_RETURN_IF_FAIL(status, "Failed to create dropout descriptor"); if (dropout == 0.f) { @@ -1130,8 +954,7 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> { DeviceMemory<uint8> state_memory; if (state_allocator) { size_t state_sizes_in_bytes = 0; - status = wrap::cudnnDropoutGetStatesSize(parent_, cudnn_handle, - &state_sizes_in_bytes); + status = cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes); CUDNN_RETURN_IF_FAIL(status, "Failed to query dropout state sizes"); auto allocated = @@ -1146,9 +969,9 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> { return; } } - status = wrap::cudnnSetDropoutDescriptor(parent_, handle_, cudnn_handle, - dropout, state_memory.opaque(), - state_memory.size(), seed); + status = cudnnSetDropoutDescriptor(handle_, cudnn.handle(), dropout, + state_memory.opaque(), + state_memory.size(), seed); CUDNN_RETURN_IF_FAIL( status, port::StrCat( "Failed to set dropout descriptor with state memory size: ", @@ -1156,11 +979,9 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> { } ~CudnnDropoutDescriptor() { - if (handle_) { - cudnnStatus_t status = - wrap::cudnnDestroyDropoutDescriptor(parent_, handle_); - CUDNN_RETURN_IF_FAIL(status, "Failed to destroy Cudnn dropout handle: "); - } + cudnnStatus_t status = cudnnDestroyDropoutDescriptor(handle_); + // TODO(csigg): This is a no-op (error is not reported). Same below. + CUDNN_RETURN_IF_FAIL(status, "Failed to destroy Cudnn dropout handle: "); } cudnnDropoutDescriptor_t handle() const { @@ -1169,8 +990,7 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> { } private: - CUDAExecutor* parent_; - cudnnDropoutDescriptor_t handle_; + cudnnDropoutDescriptor_t handle_; // Owned. 
float dropout_; uint64 seed_; SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor); @@ -1180,10 +1000,10 @@ class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> { public: typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion; typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions; - CudnnRnnParamsDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle, + CudnnRnnParamsDescriptor(const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc); ~CudnnRnnParamsDescriptor() { - cudnnStatus_t status = wrap::cudnnDestroyFilterDescriptor(parent_, handle_); + cudnnStatus_t status = cudnnDestroyFilterDescriptor(handle_); CUDNN_RETURN_IF_FAIL(status, "Failed to destroy RNN filter descriptor"); } cudnnFilterDescriptor_t handle() const { @@ -1202,7 +1022,6 @@ class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> { private: int GetRegionCountPerLayer() const; - CUDAExecutor* parent_; cudnnFilterDescriptor_t handle_; const CudnnRnnDescriptor* rnn_desc_; int64 params_size_in_bytes_; @@ -1211,19 +1030,20 @@ class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> { SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor); }; +} // namespace + class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> { public: - CudnnRnnDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle, - int num_layers, int hidden_size, int input_size, - int batch_size, cudnnRNNInputMode_t input_mode, + CudnnRnnDescriptor(const CudnnHandle& cudnn, int num_layers, int hidden_size, + int input_size, int batch_size, + cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type, cudnnDataType_t compute_type, const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed, ScratchAllocator* state_allocator) - : parent_(parent), - rnn_desc_(nullptr), + : rnn_desc_(nullptr), num_layers_(num_layers), hidden_size_(hidden_size), input_size_(input_size), @@ -1238,21 +1058,21 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> { compute_type_(compute_type), algorithm_config_(algorithm_config) { // Create the dropout handle. - cudnn_dropout_desc_.reset(new CudnnDropoutDescriptor( - parent, cudnn_handle, dropout, seed, state_allocator)); + cudnn_dropout_desc_.reset( + new CudnnDropoutDescriptor(cudnn, dropout, seed, state_allocator)); if (!cudnn_dropout_desc_->ok()) { SetFailure(cudnn_dropout_desc_->Status()); return; } // Create the RNN handle - cudnnStatus_t status = wrap::cudnnCreateRNNDescriptor(parent_, &rnn_desc_); + cudnnStatus_t status = cudnnCreateRNNDescriptor(&rnn_desc_); CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor"); #if CUDNN_VERSION >= 6000 // TODO: allow the user to choose an algorithm. 
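    // ToCudnnRNNAlgo maps the generic dnn::AlgorithmDesc carried in
    // algorithm_config onto the corresponding cudnnRNNAlgo_t value.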
rnn_algo_ = ToCudnnRNNAlgo(algorithm_config_.algorithm()); - status = wrap::cudnnSetRNNDescriptor_v6( - parent, cudnn_handle, /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size, + status = cudnnSetRNNDescriptor_v6( + cudnn.handle(), /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size, /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(), /*inputMode=*/input_mode, /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo_, /*dataType=*/compute_type); @@ -1264,26 +1084,25 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> { if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) { CHECK_GE(batch_size_, 0); - status = wrap::cudnnCreatePersistentRNNPlan( - parent, rnn_desc_, batch_size_, data_type_, &rnn_plan_); + status = cudnnCreatePersistentRNNPlan(rnn_desc_, batch_size_, data_type_, + &rnn_plan_); CUDNN_RETURN_IF_FAIL(status, "Unable to create persistent RNN plan."); - status = wrap::cudnnSetPersistentRNNPlan(parent, rnn_desc_, rnn_plan_); + status = cudnnSetPersistentRNNPlan(rnn_desc_, rnn_plan_); CUDNN_RETURN_IF_FAIL(status, "Unable to update persistent RNN plan."); } #else CHECK(algorithm_config_.is_default()) << "Non-default algorithm not supported for CUDA version < 6.0"; - status = wrap::cudnnSetRNNDescriptor( - parent, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/, - num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/, - input_mode /*inputMode*/, direction_mode /*direction*/, - rnn_mode /*mode*/, compute_type /*dataType*/); + status = cudnnSetRNNDescriptor( + /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size, + /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(), + /*inputMode=*/input_mode, /*direction=*/direction_mode, + /*mode=*/rnn_mode, /*dataType=*/compute_type); CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor"); #endif // Create the params handle. - cudnn_params_desc_.reset( - new CudnnRnnParamsDescriptor(parent, cudnn_handle, *this)); + cudnn_params_desc_.reset(new CudnnRnnParamsDescriptor(cudnn, *this)); if (!cudnn_params_desc_->ok()) { SetFailure(cudnn_params_desc_->Status()); return; @@ -1295,11 +1114,11 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> { cudnnStatus_t status; #if CUDNN_VERSION >= 6000 if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC && rnn_plan_) { - status = wrap::cudnnDestroyPersistentRNNPlan(parent_, rnn_plan_); + status = cudnnDestroyPersistentRNNPlan(rnn_plan_); CUDNN_RETURN_IF_FAIL(status, "Unable to destroy persistent RNN plan."); } #endif - status = wrap::cudnnDestroyRNNDescriptor(parent_, rnn_desc_); + status = cudnnDestroyRNNDescriptor(rnn_desc_); CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor"); } } @@ -1308,11 +1127,9 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> { cudnnMathType_t math_type = (use_tensor_op_math ? 
CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH); if (RnnTensorOpMathEnabled()) { - cudnnStatus_t status = - wrap::cudnnSetRNNMatrixMathType(parent_, rnn_desc_, math_type); + cudnnStatus_t status = cudnnSetRNNMatrixMathType(rnn_desc_, math_type); if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn RNN math type: " - << ToString(status); + LOG(FATAL) << "could not set cudnn RNN math type: " << ToString(status); } } #endif @@ -1354,7 +1171,6 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> { } private: - CUDAExecutor* parent_; cudnnRNNDescriptor_t rnn_desc_; int num_layers_; int hidden_size_; @@ -1377,30 +1193,28 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> { SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor); }; +namespace { + CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor( - CUDAExecutor* parent, cudnnHandle_t cudnn_handle, - const CudnnRnnDescriptor& rnn_desc) - : parent_(parent), - handle_(nullptr), - rnn_desc_(&rnn_desc), - params_size_in_bytes_(0) { + const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc) + : handle_(nullptr), rnn_desc_(&rnn_desc), params_size_in_bytes_(0) { cudnnTensorDescriptor_t input_desc = nullptr; { // Query the params size. - auto status = wrap::cudnnCreateTensorDescriptor(parent, &input_desc); + auto status = cudnnCreateTensorDescriptor(&input_desc); CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create tensor descriptor"); int dims[] = {1, rnn_desc.input_size(), 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; - status = wrap::cudnnSetTensorNdDescriptor( - parent, input_desc /*tensorDesc*/, rnn_desc.data_type() /*dataType*/, - sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/, - strides /*strideA*/); + status = cudnnSetTensorNdDescriptor( + /*tensorDesc=*/input_desc, rnn_desc.data_type() /*dataType*/, + sizeof(dims) / sizeof(dims[0]) /*nbDims*/, /*dimA=*/dims, + /*strideA=*/strides); CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to set tensor descriptor"); size_t params_size = 0; - status = wrap::cudnnGetRNNParamsSize( - parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/, - input_desc /*xDesc*/, ¶ms_size /*sizeInBytes*/, + status = cudnnGetRNNParamsSize( + cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, + /*xDesc=*/input_desc, /*sizeInBytes=*/¶ms_size, rnn_desc.data_type() /*dataType*/); CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get RNN parameter size"); params_size_in_bytes_ = static_cast<int64>(params_size); @@ -1408,13 +1222,13 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor( { // Create the params descriptor. 
- auto status = wrap::cudnnCreateFilterDescriptor(parent, &handle_); + auto status = cudnnCreateFilterDescriptor(&handle_); CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create RNN filter descriptor"); int dims[] = {static_cast<int>(params_size_in_bytes_), 1, 1}; - status = wrap::cudnnSetFilterNdDescriptor( - parent, handle_ /*filterDesc*/, rnn_desc.data_type() /*dataType*/, - CUDNN_TENSOR_NCHW /*format*/, sizeof(dims) / sizeof(dims[0]) /*nbDims*/, - dims /*filterDimA*/); + status = cudnnSetFilterNdDescriptor( + /*filterDesc=*/handle_, rnn_desc.data_type() /*dataType*/, + /*format=*/CUDNN_TENSOR_NCHW, sizeof(dims) / sizeof(dims[0]) /*nbDims*/, + /*filterDimA=*/dims); CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to update RNN filter descriptor"); } @@ -1422,8 +1236,7 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor( // Create the weights and biases into the params buffer int region_count_per_layer = GetRegionCountPerLayer(); cudnnFilterDescriptor_t region_desc_handle = nullptr; - auto status = - wrap::cudnnCreateFilterDescriptor(parent, ®ion_desc_handle); + auto status = cudnnCreateFilterDescriptor(®ion_desc_handle); CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create filter descriptor"); const int layer_count = rnn_desc.direction_mode() == CUDNN_UNIDIRECTIONAL ? rnn_desc.num_layers() @@ -1433,21 +1246,21 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor( for (int type = 0; type < 2; type++) { void* offset = nullptr; if (type == 0) { - status = wrap::cudnnGetRNNLinLayerMatrixParams( - parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/, - layer /*layer*/, input_desc /*xDesc*/, handle_ /*wDesc*/, - nullptr /*w*/, region /*linLayerID*/, - region_desc_handle /*linLayerMatDesc*/, - &offset /*linLayerMat*/); + status = cudnnGetRNNLinLayerMatrixParams( + cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, + /*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_, + /*w=*/nullptr, /*linLayerID=*/region, + /*linLayerMatDesc=*/region_desc_handle, + /*linLayerMat=*/&offset); CUDNN_RETURN_IF_FAIL( status, "Cudnn fails to call cudnnGetRNNLinLayerMatrixParams"); } else { - status = wrap::cudnnGetRNNLinLayerBiasParams( - parent, cudnn_handle /*rnnDesc*/, rnn_desc.handle() /*rnnDesc*/, - layer /*layer*/, input_desc /*xDesc*/, handle_ /*wDesc*/, - nullptr /*w*/, region /*linLayerID*/, - region_desc_handle /*linLayerBiasDesc*/, - &offset /*linLayerBias*/); + status = cudnnGetRNNLinLayerBiasParams( + cudnn.handle() /*rnnDesc*/, rnn_desc.handle() /*rnnDesc*/, + /*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_, + /*w=*/nullptr, /*linLayerID=*/region, + /*linLayerBiasDesc=*/region_desc_handle, + /*linLayerBias=*/&offset); CUDNN_RETURN_IF_FAIL( status, "Cudnn fails to call cudnnGetRNNLinLayerBiasParams"); } @@ -1455,15 +1268,15 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor( cudnnDataType_t data_type; cudnnTensorFormat_t tensor_format; int n_dims; - status = wrap::cudnnGetFilterNdDescriptor( - parent, region_desc_handle /*filterDesc*/, + status = cudnnGetFilterNdDescriptor( + /*filterDesc=*/region_desc_handle, sizeof(dims) / sizeof(dims[0]) /*nbDimsRequested*/, - &data_type /*dataType*/, &tensor_format /*format*/, - &n_dims /*nbDims*/, dims /*filterDimA*/); + /*dataType=*/&data_type, /*format=*/&tensor_format, + /*nbDims=*/&n_dims, /*filterDimA=*/dims); CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get filter description"); int64 size = dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(rnn_desc.data_type()); - auto region = ParamsRegion{reinterpret_cast<int64>(offset), 
size}; + ParamsRegion region = {reinterpret_cast<int64>(offset), size}; if (type == 0) { weights_.push_back(region); } else { @@ -1472,13 +1285,13 @@ CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor( } } } - status = wrap::cudnnDestroyFilterDescriptor(parent, region_desc_handle); + status = cudnnDestroyFilterDescriptor(region_desc_handle); CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy filter descriptor"); } { // Release the dummy input tensor descriptor. - auto status = wrap::cudnnDestroyTensorDescriptor(parent, input_desc); + auto status = cudnnDestroyTensorDescriptor(input_desc); CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy tensor descriptor"); } } @@ -1498,6 +1311,8 @@ int CudnnRnnParamsDescriptor::GetRegionCountPerLayer() const { } } +} // namespace + class CudnnRnnSequenceTensorDescriptor : public CudnnDescriptorCommon<dnn::RnnSequenceTensorDescriptor> { public: @@ -1517,14 +1332,14 @@ class CudnnRnnSequenceTensorDescriptor SetFailure(port::Status(port::error::UNKNOWN, error_msg)); return; } - cudnnStatus_t status = wrap::cudnnCreateTensorDescriptor(parent, &handle); + cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle); CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor"); int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; - status = wrap::cudnnSetTensorNdDescriptor( - parent, handle /*tensorDesc*/, data_type /*dataType*/, - sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/, - strides /*strideA*/); + status = cudnnSetTensorNdDescriptor( + /*tensorDesc=*/handle, /*dataType=*/data_type, + sizeof(dims) / sizeof(dims[0]) /*nbDims*/, /*dimA=*/dims, + /*strideA=*/strides); CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor"); // Replicate handle across the number of steps. handles_.assign(seq_length, handle); @@ -1532,8 +1347,7 @@ class CudnnRnnSequenceTensorDescriptor ~CudnnRnnSequenceTensorDescriptor() override { // Only the first one needs to be destroyed. All others are the same. 
- cudnnStatus_t status = - wrap::cudnnDestroyTensorDescriptor(parent_, handles_[0]); + cudnnStatus_t status = cudnnDestroyTensorDescriptor(handles_[0]); CUDNN_RETURN_IF_FAIL(status, "Failed to destroy sequence tensor descriptor"); } @@ -1570,21 +1384,20 @@ class CudnnRnnStateTensorDescriptor batch_size_(batch_size), data_size_(data_size), data_type_(data_type) { - cudnnStatus_t status = wrap::cudnnCreateTensorDescriptor(parent, &handle_); + cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_); CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor"); int dims[] = {num_layers, batch_size, data_size}; int strides[] = {dims[1] * dims[2], dims[2], 1}; - status = wrap::cudnnSetTensorNdDescriptor( - parent, handle_ /*tensorDesc*/, data_type /*dataType*/, - sizeof(dims) / sizeof(dims[0]) /*nbDims*/, dims /*dimA*/, - strides /*strideA*/); + status = cudnnSetTensorNdDescriptor( + /*tensorDesc=*/handle_, /*dataType=*/data_type, + sizeof(dims) / sizeof(dims[0]) /*nbDims*/, /*dimA=*/dims, + /*strideA=*/strides); CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor"); } ~CudnnRnnStateTensorDescriptor() override { if (!handle_) { - cudnnStatus_t status = - wrap::cudnnDestroyTensorDescriptor(parent_, handle_); + cudnnStatus_t status = cudnnDestroyTensorDescriptor(handle_); CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN state tensor"); } } @@ -1679,13 +1492,13 @@ bool ExtractAndCheckRnnForward( return true; } -bool CheckRNNParameterSize(CUDAExecutor* parent, cudnnHandle_t cudnn_handle, +bool CheckRNNParameterSize(const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc) { size_t params_size_in_bytes = 0; - cudnnStatus_t status = wrap::cudnnGetRNNParamsSize( - parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/, - input_desc.handles()[0] /*xDesc*/, &params_size_in_bytes /*sizeInBytes*/, + cudnnStatus_t status = cudnnGetRNNParamsSize( + /*handle=*/cudnn.handle(), rnn_desc.handle() /*rnnDesc*/, + input_desc.handles()[0] /*xDesc*/, /*sizeInBytes=*/&params_size_in_bytes, rnn_desc.data_type() /*dataType*/); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "Unable to check RNN param size: " << ToString(status); @@ -1695,18 +1508,17 @@ bool CheckRNNParameterSize(CUDAExecutor* parent, cudnnHandle_t cudnn_handle, rnn_desc.ParamsSizeInBytes(); } -bool CreateRnnWorkspace(Stream* stream, CUDAExecutor* parent, - cudnnHandle_t cudnn_handle, +bool CreateRnnWorkspace(Stream* stream, const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, ScratchAllocator* workspace_allocator, DeviceMemory<uint8>* workspace) { // Query the workspace size.
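CheckRNNParameterSize above shows the usual cuDNN idiom of asking the library for buffer sizes rather than computing them by hand; the same idiom recurs immediately below for the RNN workspace. A self-contained version of the params-size query, under the cuDNN 5 to 7 era API this file targets (the helper name is hypothetical):

#include <cstdio>

#include <cudnn.h>

// Returns the RNN weight-blob size in bytes, or 0 on failure. `rnn_desc` and
// the per-step input descriptor `x_desc` must already be configured, as done
// by CudnnRnnDescriptor and CudnnRnnSequenceTensorDescriptor above.
size_t QueryRnnParamsSize(cudnnHandle_t handle, cudnnRNNDescriptor_t rnn_desc,
                          cudnnTensorDescriptor_t x_desc,
                          cudnnDataType_t data_type) {
  size_t size_in_bytes = 0;
  cudnnStatus_t status = cudnnGetRNNParamsSize(handle, rnn_desc, x_desc,
                                               &size_in_bytes, data_type);
  if (status != CUDNN_STATUS_SUCCESS) {
    std::fprintf(stderr, "cudnnGetRNNParamsSize failed: %s\n",
                 cudnnGetErrorString(status));
    return 0;
  }
  return size_in_bytes;
}

The caller-side check then reduces to comparing this value against the size recorded in the params descriptor, exactly as the function above does.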
size_t workspace_size_in_bytes = 0; - cudnnStatus_t status = wrap::cudnnGetRNNWorkspaceSize( - parent, cudnn_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/, - input_desc.seq_length() /*seqLength*/, input_desc.handles() /*xDesc*/, - &workspace_size_in_bytes /*sizeInBytes*/); + cudnnStatus_t status = cudnnGetRNNWorkspaceSize( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), + /*seqLength=*/input_desc.seq_length(), /*xDesc=*/input_desc.handles(), + /*sizeInBytes=*/&workspace_size_in_bytes); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "Unable to query workspace size: " << ToString(status); return false; @@ -1758,25 +1570,18 @@ bool CudnnSupport::DoRnnForwardImpl( return false; } - // check params size - mutex_lock lock{dnn_handle_mutex_}; - auto set_stream_status = - wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (set_stream_status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "failed to set stream for cudnn handle: " - << ToString(set_stream_status); - } + auto cudnn = cudnn_->GetHandle(parent_, stream); - if (!CheckRNNParameterSize(parent_, ToHandle(dnn_handle_), rnn_desc, - input_desc)) { + // check params size + if (!CheckRNNParameterSize(cudnn, rnn_desc, input_desc)) { LOG(ERROR) << "Invalid parameters"; return false; } // create the workspace DeviceMemory<uint8> workspace; - if (!CreateRnnWorkspace(stream, parent_, ToHandle(dnn_handle_), rnn_desc, - input_desc, workspace_allocator, &workspace)) { + if (!CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, + workspace_allocator, &workspace)) { LOG(ERROR) << "Unable to create rnn workspace"; return false; } @@ -1786,11 +1591,10 @@ bool CudnnSupport::DoRnnForwardImpl( DeviceMemory<uint8> reserve_space; if (is_training) { size_t reserve_space_size_in_bytes = 0; - cudnnStatus_t status = wrap::cudnnGetRNNTrainingReserveSize( - parent_, ToHandle(dnn_handle_) /*handle*/, - rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, - input_desc.handles() /*xDesc*/, - &reserve_space_size_in_bytes /*sizeInBytes*/); + cudnnStatus_t status = cudnnGetRNNTrainingReserveSize( + cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, + /*seqLength=*/model_dims.seq_length, input_desc.handles() /*xDesc*/, + /*sizeInBytes=*/&reserve_space_size_in_bytes); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "Unable to query reserve space size: " << ToString(status); return false; @@ -1825,30 +1629,28 @@ bool CudnnSupport::DoRnnForwardImpl( // make the forward call cudnnStatus_t status; if (!is_training) { - status = wrap::cudnnRNNForwardInference( - this, stream, ToHandle(dnn_handle_) /*handle*/, - rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, - input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/, - input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/, - input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/, - rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/, - output_desc.handles() /*yDesc*/, output_data->opaque() /*y*/, - output_h_desc.handle() /*hyDesc*/, output_h_data->opaque() /*hy*/, - output_c_desc.handle() /*cyDesc*/, output_c_data->opaque() /*cy*/, - workspace.opaque() /*workspace*/, + status = cudnnRNNForwardInference( + cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, + model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/, + input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/, + input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/, + input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/, + 
params.opaque() /*w*/, output_desc.handles() /*yDesc*/, + output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/, + output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/, + output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/, workspace.size() /*workSpaceSizeInBytes*/); } else { - status = wrap::cudnnRNNForwardTraining( - this, stream, ToHandle(dnn_handle_) /*handle*/, - rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, - input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/, - input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/, - input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/, - rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/, - output_desc.handles() /*yDesc*/, output_data->opaque() /*y*/, - output_h_desc.handle() /*hyDesc*/, output_h_data->opaque() /*hy*/, - output_c_desc.handle() /*cyDesc*/, output_c_data->opaque() /*cy*/, - workspace.opaque() /*workspace*/, + status = cudnnRNNForwardTraining( + cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, + model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/, + input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/, + input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/, + input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/, + params.opaque() /*w*/, output_desc.handles() /*yDesc*/, + output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/, + output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/, + output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/, workspace.size() /*workSpaceSizeInBytes*/, reserve_space.opaque() /*reserveSpace*/, reserve_space.size() /*reserveSpaceSizeInBytes*/); @@ -1914,25 +1716,18 @@ bool CudnnSupport::DoRnnBackwardImpl( return false; } - // check params size - mutex_lock lock{dnn_handle_mutex_}; - auto set_stream_status = - wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (set_stream_status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "failed to set stream for cudnn handle: " - << ToString(set_stream_status); - } + auto cudnn = cudnn_->GetHandle(parent_, stream); - if (!CheckRNNParameterSize(parent_, ToHandle(dnn_handle_), rnn_desc, - input_desc)) { + // check params size + if (!CheckRNNParameterSize(cudnn, rnn_desc, input_desc)) { LOG(ERROR) << "Invalid parameters"; return false; } // create the workspace DeviceMemory<uint8> workspace; - if (!CreateRnnWorkspace(stream, parent_, ToHandle(dnn_handle_), rnn_desc, - input_desc, workspace_allocator, &workspace)) { + if (!CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, + workspace_allocator, &workspace)) { LOG(ERROR) << "Unable to create rnn workspace"; return false; } @@ -1952,12 +1747,11 @@ bool CudnnSupport::DoRnnBackwardImpl( } } // make the backward data call - cudnnStatus_t status = wrap::cudnnRNNBackwardData( - this, stream, ToHandle(dnn_handle_) /*handle*/, - rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, - output_desc.handles() /*yDesc*/, output_data.opaque() /*y*/, - output_desc.handles() /*dyDesc*/, output_backprop_data.opaque() /*dy*/, - output_h_desc.handle() /*dhyDesc*/, + cudnnStatus_t status = cudnnRNNBackwardData( + cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, + model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/, + output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/, + output_backprop_data.opaque() /*dy*/, output_h_desc.handle() /*dhyDesc*/, output_h_backprop_data.opaque() /*dhy*/, output_c_desc.handle() 
/*dcyDesc*/, output_c_backprop_data.opaque() /*dcy*/, @@ -1985,13 +1779,12 @@ bool CudnnSupport::DoRnnBackwardImpl( // Clear the dw to zeros. stream->ThenMemZero(params_backprop_data, params_backprop_data->size()); // make the backward weight call - status = wrap::cudnnRNNBackwardWeights( - this, stream, ToHandle(dnn_handle_) /*handle*/, - rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/, - input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/, - input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/, - output_desc.handles() /*yDesc*/, output_data.opaque() /*y*/, - workspace.opaque() /*workspace*/, + status = cudnnRNNBackwardWeights( + cudnn.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/, + model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/, + input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/, + input_h_data.opaque() /*hx*/, output_desc.handles() /*yDesc*/, + output_data.opaque() /*y*/, workspace.opaque() /*workspace*/, workspace.size() /*workSpaceSizeInBytes*/, rnn_desc.params_handle() /*dwDesc*/, params_backprop_data->opaque() /*dw*/, @@ -2029,13 +1822,15 @@ CudnnSupport::createRnnDescriptor( const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed, ScratchAllocator* state_allocator) { #if CUDNN_VERSION >= 5000 - mutex_lock lock{dnn_handle_mutex_}; + // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's + // not enqueueing anything into a stream, we pass in the null stream. + auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr); std::unique_ptr<CudnnRnnDescriptor> rnn_desc(new CudnnRnnDescriptor( - parent_, ToHandle(dnn_handle_), num_layers, hidden_size, input_size, - batch_size, ToCudnnRnnInputMode(input_mode), - ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode), - ToCudnnDataType(data_type), GetRnnComputeType(data_type), - algorithm_config, dropout, seed, state_allocator)); + cudnn, num_layers, hidden_size, input_size, batch_size, + ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode), + ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type), + GetRnnComputeType(data_type), algorithm_config, dropout, seed, + state_allocator)); if (!rnn_desc->ok()) { return rnn_desc->Status(); } @@ -2046,7 +1841,7 @@ CudnnSupport::createRnnDescriptor( port::StrCat("createRnnDescriptor needs at least Cudnn 5.0 to work. ", "Current Cudnn version: ", CUDNN_VERSION, ". "); LOG(ERROR) << error_msg; - return port::Status{port::error::UNIMPLEMENTED, error_msg}; + return port::Status(port::error::UNIMPLEMENTED, error_msg); #endif // CUDNN_VERSION } @@ -2069,7 +1864,7 @@ CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size, "createRnnSequenceTensorDescriptor needs at least Cudnn 5.0 to work. ", "Current Cudnn version: ", CUDNN_VERSION, ". "); LOG(ERROR) << error_msg; - return port::Status{port::error::UNIMPLEMENTED, error_msg}; + return port::Status(port::error::UNIMPLEMENTED, error_msg); #endif // CUDNN_VERSION } @@ -2091,7 +1886,7 @@ CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size, "createRnnStateTensorDescriptor needs at least Cudnn 5.0 to work. ", "Current Cudnn version: ", CUDNN_VERSION, ". 
"); LOG(ERROR) << error_msg; - return port::Status{port::error::UNIMPLEMENTED, error_msg}; + return port::Status(port::error::UNIMPLEMENTED, error_msg); #endif // CUDNN_VERSION } @@ -2393,35 +2188,26 @@ bool CudnnSupport::DoRnnBackward( namespace { inline cudnnConvolutionFwdAlgo_t GetCudnnConvolutionForwardAlgo( - Stream* stream, CUDAExecutor* parent, void* dnn_handle, - const ScopedTensorDescriptor& input_nd, + const CudnnHandle& cudnn, const ScopedTensorDescriptor& input_nd, const ScopedFilterDescriptor& filter, const ScopedConvolutionDescriptor& conv, const ScopedTensorDescriptor& output_nd, bool specify_workspace_limit, - ScratchAllocator* scratch_allocator) { + size_t memory_limit_bytes) { cudnnConvolutionFwdPreference_t preference = specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; - auto memory_limit_bytes = - scratch_allocator == nullptr - ? 0 - : scratch_allocator->GetMemoryLimitInBytes(stream); - if (memory_limit_bytes < 0) { - memory_limit_bytes = 0; - } cudnnConvolutionFwdAlgo_t algo_to_use; - auto status = wrap::cudnnGetConvolutionForwardAlgorithm( - parent, ToHandle(dnn_handle), input_nd.handle(), filter.handle(), - conv.handle(), output_nd.handle(), preference, memory_limit_bytes, - &algo_to_use); + auto status = cudnnGetConvolutionForwardAlgorithm( + cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(), + output_nd.handle(), preference, memory_limit_bytes, &algo_to_use); CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable algorithm for doing forward convolution"; return algo_to_use; } dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm( - Stream* stream, CUDAExecutor* parent, void* dnn_handle, + Stream* stream, const CudnnHandle& cudnn, const dnn::AlgorithmConfig& algorithm_config, bool is_profiling, const ScopedTensorDescriptor& input_nd, const ScopedFilterDescriptor& filter, @@ -2432,19 +2218,29 @@ dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm( bool use_tensor_ops; if (algorithm_config.algorithm().is_default()) { use_tensor_ops = true; + + auto memory_limit_bytes = + scratch_allocator == nullptr + ? 
0 + : scratch_allocator->GetMemoryLimitInBytes(stream); + if (memory_limit_bytes < 0) { + memory_limit_bytes = 0; + } + algo = GetCudnnConvolutionForwardAlgo( - stream, parent, dnn_handle, input_nd, filter, conv, output_nd, + cudnn, input_nd, filter, conv, output_nd, /*specify_workspace_limit=*/scratch_allocator != nullptr, - scratch_allocator); + memory_limit_bytes); } else { use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled(); algo = ToConvForwardAlgo(algorithm_config.algorithm()); } size_t size_in_bytes; - auto status = wrap::cudnnGetConvolutionForwardWorkspaceSize( - parent, ToHandle(dnn_handle), /*srcDesc=*/input_nd.handle(), - /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*destDesc=*/output_nd.handle(), /*algo=*/algo, + auto status = cudnnGetConvolutionForwardWorkspaceSize( + cudnn.handle(), + /*xDesc=*/input_nd.handle(), + /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), + /*yDesc=*/output_nd.handle(), /*algo=*/algo, /*sizeInBytes=*/&size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(status != CUDNN_STATUS_SUCCESS)) { @@ -2484,8 +2280,8 @@ dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm( if (algorithm_config.algorithm_no_scratch().is_default()) { use_tensor_ops = true; algo = GetCudnnConvolutionForwardAlgo( - stream, parent, dnn_handle, input_nd, filter, conv, output_nd, - /*specify_workspace_limit=*/false, nullptr); + cudnn, input_nd, filter, conv, output_nd, + /*specify_workspace_limit=*/false, 0); } else { use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled(); algo = ToConvForwardAlgo(algorithm_config.algorithm_no_scratch()); @@ -2614,11 +2410,12 @@ cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) { LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type); } } + } // namespace template <class T> bool CudnnSupport::DoConvolveImpl( - Stream* stream, const BatchDescriptor& batch_descriptor, + Stream* stream, const BatchDescriptor& input_descriptor, const DeviceMemory<T>& input_data, const FilterDescriptor& filter_descriptor, const DeviceMemory<T>& filter_data, @@ -2628,18 +2425,13 @@ bool CudnnSupport::DoConvolveImpl( const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { cudnnDataType_t cudnn_type = GetCudnnDataType<T>(); - ScopedTensorDescriptor input_nd{parent_, batch_descriptor, cudnn_type}; - ScopedTensorDescriptor output_nd{parent_, output_descriptor, cudnn_type}; - ScopedFilterDescriptor filter{parent_, filter_descriptor, batch_descriptor, - cudnn_type}; - ScopedConvolutionDescriptor conv{parent_, convolution_descriptor, - GetConvComputeType<T>()}; - - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status); - } + ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type); + ScopedTensorDescriptor output_nd(output_descriptor, cudnn_type); + ScopedFilterDescriptor filter(filter_descriptor, cudnn_type); + ScopedConvolutionDescriptor conv(convolution_descriptor, + GetConvComputeType<T>()); + + auto cudnn = cudnn_->GetHandle(parent_, stream); // Alpha is the scaling factor for input. float falpha = 1.0; double dalpha = 1.0; @@ -2660,42 +2452,41 @@ bool CudnnSupport::DoConvolveImpl( // GetCudnnConvolutionForwardAlgorithm(). if (algorithm_config.algorithm().is_default()) { // With the default algorithm, use Cudnn's heuristics. 
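Both the lambda removed in the next hunk and the GetCudnnConvolutionForwardAlgo helper introduced above funnel into the same heuristic entry point, cudnnGetConvolutionForwardAlgorithm (part of the cuDNN 5 to 7 API this file targets; later cuDNN releases removed it in favor of the _v7 variants). A minimal sketch of that query, assuming the four descriptors are already configured:

#include <cudnn.h>

// Asks cuDNN for a forward-convolution algorithm under an optional workspace
// budget. A zero budget mirrors the CUDNN_CONVOLUTION_FWD_NO_WORKSPACE branch
// above; if the query fails, the pre-initialized implicit-GEMM value is
// returned as a conservative fallback.
cudnnConvolutionFwdAlgo_t PickForwardAlgo(
    cudnnHandle_t handle, cudnnTensorDescriptor_t x_desc,
    cudnnFilterDescriptor_t w_desc, cudnnConvolutionDescriptor_t conv_desc,
    cudnnTensorDescriptor_t y_desc, size_t memory_limit_bytes) {
  cudnnConvolutionFwdPreference_t preference =
      memory_limit_bytes > 0 ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
                             : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
  cudnnConvolutionFwdAlgo_t algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
  cudnnGetConvolutionForwardAlgorithm(handle, x_desc, w_desc, conv_desc,
                                      y_desc, preference, memory_limit_bytes,
                                      &algo);
  return algo;
}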
- auto get_algorithm = - [&](bool specify_limit) SHARED_LOCKS_REQUIRED(dnn_handle_mutex_) { - cudnnConvolutionFwdPreference_t preference = - specify_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT - : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; - - auto memory_limit_bytes = - scratch_allocator == nullptr - ? 0 - : scratch_allocator->GetMemoryLimitInBytes(stream); - if (memory_limit_bytes < 0) { - memory_limit_bytes = 0; - } + auto get_algorithm = [&](bool specify_limit) { + cudnnConvolutionFwdPreference_t preference = + specify_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT + : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; - cudnnConvolutionFwdAlgo_t algo_to_use; - status = wrap::cudnnGetConvolutionForwardAlgorithm( - parent_, ToHandle(dnn_handle_), input_nd.handle(), - filter.handle(), conv.handle(), output_nd.handle(), - /*preference=*/preference, - /*memoryLimitInBytes=*/memory_limit_bytes, - /*algo=*/&algo_to_use); - CHECK_EQ(status, CUDNN_STATUS_SUCCESS) - << "Unable to find a suitable " - "algorithm for doing forward " - "convolution"; - return algo_to_use; - }; + auto memory_limit_bytes = + scratch_allocator == nullptr + ? 0 + : scratch_allocator->GetMemoryLimitInBytes(stream); + if (memory_limit_bytes < 0) { + memory_limit_bytes = 0; + } + + cudnnConvolutionFwdAlgo_t algo_to_use; + auto status = cudnnGetConvolutionForwardAlgorithm( + cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(), + output_nd.handle(), + /*preference=*/preference, + /*memoryLimitInBytes=*/memory_limit_bytes, + /*algo=*/&algo_to_use); + CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable " + "algorithm for doing forward " + "convolution"; + return algo_to_use; + }; algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr); use_tensor_ops = true; if (scratch_allocator != nullptr) { size_t size_in_bytes; - status = wrap::cudnnGetConvolutionForwardWorkspaceSize( - parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(), - /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*destDesc=*/output_nd.handle(), /*algo=*/algo, + auto status = cudnnGetConvolutionForwardWorkspaceSize( + cudnn.handle(), + /*xDesc=*/input_nd.handle(), + /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), + /*yDesc=*/output_nd.handle(), /*algo=*/algo, /*sizeInBytes=*/&size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) { @@ -2727,10 +2518,11 @@ bool CudnnSupport::DoConvolveImpl( use_tensor_ops = algotype.tensor_ops_enabled(); conv.set_use_tensor_op_math(use_tensor_ops); size_t size_in_bytes; - status = wrap::cudnnGetConvolutionForwardWorkspaceSize( - parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(), - /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*destDesc=*/output_nd.handle(), /*algo=*/algo, + auto status = cudnnGetConvolutionForwardWorkspaceSize( + cudnn.handle(), + /*xDesc=*/input_nd.handle(), + /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), + /*yDesc=*/output_nd.handle(), /*algo=*/algo, /*sizeInBytes=*/&size_in_bytes); if (status != CUDNN_STATUS_SUCCESS) { if (is_profiling) { @@ -2785,8 +2577,8 @@ bool CudnnSupport::DoConvolveImpl( return false; } } - status = wrap::cudnnConvolutionForward( - this, stream, ToHandle(dnn_handle_), + auto status = cudnnConvolutionForward( + cudnn.handle(), /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(), /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(), @@ -2840,30 
+2632,22 @@ bool CudnnSupport::DoFusedConvolveImpl( "supported for cuDNN version >= 6"; return false; #else - ScopedTensorDescriptor conv_input_nd{ - parent_, conv_input_descriptor, - static_cast<cudnnDataType_t>(cudnn_data_type)}; - ScopedTensorDescriptor output_nd{ - parent_, output_descriptor, - static_cast<cudnnDataType_t>(cudnn_data_type)}; - ScopedFilterDescriptor filter{parent_, filter_descriptor, - conv_input_descriptor, - static_cast<cudnnDataType_t>(cudnn_data_type)}; - ScopedTensorDescriptor bias_nd{parent_, bias_descriptor, CUDNN_DATA_FLOAT}; - ScopedConvolutionDescriptor conv{ - parent_, convolution_descriptor, - static_cast<cudnnDataType_t>(cudnn_compute_type)}; - - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - CHECK(status == CUDNN_STATUS_SUCCESS) - << "failed to set stream for cudnn handle: " << ToString(status); - + ScopedTensorDescriptor conv_input_nd( + conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type)); + ScopedTensorDescriptor output_nd( + output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type)); + ScopedFilterDescriptor filter(filter_descriptor, + static_cast<cudnnDataType_t>(cudnn_data_type)); + ScopedTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT); + ScopedConvolutionDescriptor conv( + convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type)); + + auto cudnn = cudnn_->GetHandle(parent_, stream); const bool is_profiling = output_profile_result != nullptr; DeviceMemory<uint8> scratch; dnn::AlgorithmDesc algotype = GetCudnnConvolutionForwardAlgorithm( - stream, parent_, dnn_handle_, algorithm_config, is_profiling, - conv_input_nd, filter, conv, output_nd, scratch_allocator, &scratch); + stream, cudnn, algorithm_config, is_profiling, conv_input_nd, filter, + conv, output_nd, scratch_allocator, &scratch); if (algotype.is_default()) { if (!is_profiling) { LOG(ERROR) << "No suitable algorithm found"; @@ -2897,9 +2681,8 @@ bool CudnnSupport::DoFusedConvolveImpl( // activation descriptor. Note that this will change the nan propagation // behavior from separate conv, bias, and relu (which by default is // CUDNN_PROPAGATE_NAN). - ScopedActivationDescriptor activation_desc{parent_, activation_mode, - CUDNN_NOT_PROPAGATE_NAN, - output_descriptor.value_max()}; + ScopedActivationDescriptor activation_desc( + activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max()); auto side_input_data_ptr = (side_input_scale == 0) ?
output_data->opaque() : side_input_data.opaque(); @@ -2920,8 +2703,9 @@ bool CudnnSupport::DoFusedConvolveImpl( << "\noutput_nd.handle() = " << output_nd.handle() << "\noutput_data->opaque() = " << output_data->opaque(); - status = wrap::cudnnConvolutionBiasActivationForward( - this, stream, ToHandle(dnn_handle_), /*alpha1=*/&conv_input_scale, + auto status = cudnnConvolutionBiasActivationForward( + cudnn.handle(), + /*alpha1=*/&conv_input_scale, /*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(), /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(), algo, /*workSpace=*/scratch.opaque(), @@ -3125,17 +2909,9 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl( DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var, bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var, std::function<void()> inv_var_to_var) { - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status); - return false; - } - - ScopedTensorDescriptor x_descriptor{parent_, x_desc, - ToCudnnDataType(input_data_type)}; - ScopedTensorDescriptor scale_offset_descriptor{ - parent_, scale_offset_desc, ToCudnnDataType(scale_data_type)}; + ScopedTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type)); + ScopedTensorDescriptor scale_offset_descriptor( + scale_offset_desc, ToCudnnDataType(scale_data_type)); cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; #if CUDNN_VERSION >= 7000 if (BatchnormSpatialPersistentEnabled() && is_training) { @@ -3144,7 +2920,9 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl( #endif float one = 1.0; float zero = 0.0; + auto cudnn = cudnn_->GetHandle(parent_, stream); + auto status = CUDNN_STATUS_SUCCESS; if (is_training) { CHECK_EQ(batch_mean->is_null(), batch_var->is_null()) << "batch_mean and batch_var must both be null or both be non-null"; @@ -3161,11 +2939,11 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl( batch_var_opaque = nullptr; } - status = wrap::cudnnBatchNormalizationForwardTraining( - this, stream, ToHandle(dnn_handle_), mode, &one, &zero, - x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), - scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(), 1.0, - batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(), + status = cudnnBatchNormalizationForwardTraining( + cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), + x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), + scale.opaque(), offset.opaque(), 1.0, batch_mean_opaque, + batch_var_opaque, epsilon, saved_mean->opaque(), saved_inv_var->opaque()); #if CUDNN_VERSION < 5000 CHECK(inv_var_to_var); @@ -3178,11 +2956,11 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl( #else const void* maybe_inv_var = estimated_variance.opaque(); #endif - status = wrap::cudnnBatchNormalizationForwardInference( - this, stream, ToHandle(dnn_handle_), mode, &one, &zero, - x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), - scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(), - estimated_mean.opaque(), maybe_inv_var, epsilon); + status = cudnnBatchNormalizationForwardInference( + cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), + x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), + scale.opaque(), offset.opaque(), 
estimated_mean.opaque(), maybe_inv_var, + epsilon); } if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "failed to enqueue forward batch normalization on stream: " @@ -3229,18 +3007,10 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl( const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop, DeviceMemory<U>* offset_backprop) { - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status); - return false; - } - - ScopedTensorDescriptor x_descriptor{ - parent_, x_desc, static_cast<cudnnDataType_t>(cudnn_input_type)}; - ScopedTensorDescriptor scale_offset_descriptor{ - parent_, scale_offset_desc, - static_cast<cudnnDataType_t>(cudnn_scale_type)}; + ScopedTensorDescriptor x_descriptor( + x_desc, static_cast<cudnnDataType_t>(cudnn_input_type)); + ScopedTensorDescriptor scale_offset_descriptor( + scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type)); cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; #if CUDNN_VERSION >= 7000 if (BatchnormSpatialPersistentEnabled()) { @@ -3250,10 +3020,12 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl( float one = 1.0; float zero = 0.0; - status = wrap::cudnnBatchNormalizationBackward( - this, stream, ToHandle(dnn_handle_), mode, &one, &zero, &one, &zero, - x_descriptor.handle(), x.opaque(), x_descriptor.handle(), - y_backprop.opaque(), x_descriptor.handle(), x_backprop->opaque(), + auto cudnn = cudnn_->GetHandle(parent_, stream); + + auto status = cudnnBatchNormalizationBackward( + cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(), + x.opaque(), x_descriptor.handle(), y_backprop.opaque(), + x_descriptor.handle(), x_backprop->opaque(), scale_offset_descriptor.handle(), scale.opaque(), scale_backprop->opaque(), offset_backprop->opaque(), epsilon, mean.opaque(), inv_var.opaque()); @@ -3416,11 +3188,21 @@ bool CudnnSupport::DoFusedConvolve( #endif } -template<class T> -DeviceMemory<T> CudnnSupport::MaybeTransformLayout( - Stream* stream, - BatchDescriptor* output_descriptor, - DeviceMemory<T> backward_output_data, +namespace { +// NOTE(keveman): Temporary data layout transformation until cuDNN supports +// kBatchYXDepth for backward pass. This function allocates temporary memory, +// lays out the source data into the temporary but in the kBatchDepthXY +// layout, and returns the temporary memory. The caller is responsible for +// deallocating the temporary. Since the allocation is done using Stream's +// AllocateTemporaryMemory, a later BlockHostUntilDone could be used for +// deallocation. +// +// transform_scratch is populated with a legitimate temporary allocation iff +// the original output data needs to be transformed. 
+template <class T> +DeviceMemory<T> MaybeTransformLayout( + Stream* stream, const CudnnHandle& cudnn, + BatchDescriptor* output_descriptor, DeviceMemory<T> backward_output_data, std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch) { if (output_descriptor->layout() == dnn::DataLayout::kBatchDepthYX) { return backward_output_data; @@ -3433,15 +3215,14 @@ DeviceMemory<T> CudnnSupport::MaybeTransformLayout( transformed_output_descriptor.CloneFrom(*output_descriptor); transformed_output_descriptor.set_layout(dnn::DataLayout::kBatchDepthYX); cudnnDataType_t cudnn_type = GetCudnnDataType<T>(); - ScopedTensorDescriptor orig_out_back_nd{parent_, *output_descriptor, - cudnn_type}; - ScopedTensorDescriptor transformed_out_back_nd{ - parent_, transformed_output_descriptor, cudnn_type}; + ScopedTensorDescriptor orig_out_back_nd(*output_descriptor, cudnn_type); + ScopedTensorDescriptor transformed_out_back_nd(transformed_output_descriptor, + cudnn_type); float alpha = 1.0f; float beta = 0.0f; - auto status = wrap::cudnnTransformTensor( - this, stream, ToHandle(dnn_handle_), &alpha, orig_out_back_nd.handle(), + auto status = cudnnTransformTensor( + cudnn.handle(), &alpha, orig_out_back_nd.handle(), backward_output_data.opaque(), &beta, transformed_out_back_nd.handle(), (*transform_scratch)->mutable_device_memory()->opaque()); @@ -3451,6 +3232,7 @@ DeviceMemory<T> CudnnSupport::MaybeTransformLayout( output_descriptor->set_layout(dnn::DataLayout::kBatchDepthYX); return (*transform_scratch)->device_memory(); } +} // namespace bool CudnnSupport::DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc, @@ -3459,21 +3241,15 @@ bool CudnnSupport::DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& output_desc, dnn::DataType output_type, float scale, DeviceMemoryBase* output_data) { - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status); - } - float beta = 0.0f; ScopedTensorDescriptor input_tensor_desc( - parent_, input_desc, ToCudnnDataType(input_type, input_desc.layout())); + input_desc, ToCudnnDataType(input_type, input_desc.layout())); ScopedTensorDescriptor output_tensor_desc( - parent_, output_desc, ToCudnnDataType(output_type, output_desc.layout())); - status = wrap::cudnnTransformTensor( - this, stream, ToHandle(dnn_handle_), &scale, input_tensor_desc.handle(), - input_data.opaque(), &beta, output_tensor_desc.handle(), - output_data->opaque()); + output_desc, ToCudnnDataType(output_type, output_desc.layout())); + auto cudnn = cudnn_->GetHandle(parent_, stream); + auto status = cudnnTransformTensor( + cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(), + &beta, output_tensor_desc.handle(), output_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "Could not transform a tensor with layout " << input_desc.ToString() << " and data type " @@ -3487,8 +3263,7 @@ bool CudnnSupport::DoTransformTensor(Stream* stream, template <class T> bool CudnnSupport::DoConvolveBackwardDataImpl( - Stream* stream, - const FilterDescriptor& filter_descriptor, + Stream* stream, const FilterDescriptor& filter_descriptor, const DeviceMemory<T>& filter_data, const BatchDescriptor& output_descriptor_in, DeviceMemory<T> backward_output_data, @@ -3497,12 +3272,6 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( DeviceMemory<T>* backward_input_data, ScratchAllocator* scratch_allocator, 
const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status); - } - cudnnDataType_t cudnn_type = GetCudnnDataType<T>(); // Alpha is the scaling factor for input. float falpha = 1.0; @@ -3515,19 +3284,21 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta) : static_cast<void*>(&fbeta); + auto cudnn = cudnn_->GetHandle(parent_, stream); + // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass. BatchDescriptor output_descriptor; output_descriptor.CloneFrom(output_descriptor_in); std::unique_ptr<TemporaryDeviceMemory<T>> transform_scratch; - backward_output_data = MaybeTransformLayout( - stream, &output_descriptor, backward_output_data, &transform_scratch); + backward_output_data = + MaybeTransformLayout(stream, cudnn, &output_descriptor, + backward_output_data, &transform_scratch); - ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, cudnn_type}; - ScopedTensorDescriptor in_back_nd{parent_, input_descriptor, cudnn_type}; - ScopedFilterDescriptor filter{parent_, filter_descriptor, input_descriptor, - cudnn_type}; - ScopedConvolutionDescriptor conv{parent_, convolution_descriptor, - GetConvComputeType<T>()}; + ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type); + ScopedTensorDescriptor in_back_nd(input_descriptor, cudnn_type); + ScopedFilterDescriptor filter(filter_descriptor, cudnn_type); + ScopedConvolutionDescriptor conv(convolution_descriptor, + GetConvComputeType<T>()); const bool is_profiling = output_profile_result != nullptr; cudnnConvolutionBwdDataAlgo_t algo; @@ -3535,8 +3306,8 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( if (algorithm_config.algorithm().is_default()) { // With the default algorithm, use Cudnn's heuristics. - auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED( - dnn_handle_mutex_) -> cudnnConvolutionBwdDataAlgo_t { + auto get_algorithm = + [&](bool specify_limit) -> cudnnConvolutionBwdDataAlgo_t { cudnnConvolutionBwdDataPreference_t preference = specify_limit ? 
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE; @@ -3549,8 +3320,8 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( memory_limit_bytes = 0; } cudnnConvolutionBwdDataAlgo_t algo_to_use; - cudnnStatus_t status = wrap::cudnnGetConvolutionBackwardDataAlgorithm( - parent_, ToHandle(dnn_handle_), + cudnnStatus_t status = cudnnGetConvolutionBackwardDataAlgorithm( + cudnn.handle(), /*filterDesc=*/filter.handle(), /*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(), @@ -3568,8 +3339,8 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( if (scratch_allocator != nullptr) { size_t size_in_bytes; - status = wrap::cudnnGetConvolutionBackwardDataWorkspaceSize( - parent_, ToHandle(dnn_handle_), + auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( + cudnn.handle(), /*filterDesc=*/filter.handle(), /*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(), @@ -3605,8 +3376,8 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( algo = ToConvBackwardDataAlgo(algotype); conv.set_use_tensor_op_math(algotype.tensor_ops_enabled()); size_t size_in_bytes; - status = wrap::cudnnGetConvolutionBackwardDataWorkspaceSize( - parent_, ToHandle(dnn_handle_), + auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( + cudnn.handle(), /*filterDesc=*/filter.handle(), /*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(), @@ -3663,23 +3434,24 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( } #if CUDNN_VERSION >= 5000 - status = wrap::cudnnConvolutionBackwardData( + auto status = + cudnnConvolutionBackwardData(cudnn.handle(), #else - status = wrap::cudnnConvolutionBackwardData_v3( + auto status = + cudnnConvolutionBackwardData_v3(cudnn.handle(), #endif - this, stream, ToHandle(dnn_handle_), - /*alpha=*/alpha, - /*filterDesc=*/filter.handle(), - /*filterData=*/filter_data.opaque(), - /*diffDesc=*/out_back_nd.handle(), - /*diffData=*/backward_output_data.opaque(), - /*convDesc=*/conv.handle(), - /*algo=*/algo, - /*workSpace=*/scratch.opaque(), - /*workSpaceSizeInBytes=*/scratch.size(), - /*beta=*/beta, - /*gradDesc=*/in_back_nd.handle(), - /*gradData=*/backward_input_data->opaque()); + /*alpha=*/alpha, + /*wDesc=*/filter.handle(), + /*w=*/filter_data.opaque(), + /*dyDesc=*/out_back_nd.handle(), + /*dy=*/backward_output_data.opaque(), + /*convDesc=*/conv.handle(), + /*algo=*/algo, + /*workSpace=*/scratch.opaque(), + /*workSpaceSizeInBytes=*/scratch.size(), + /*beta=*/beta, + /*dxDesc=*/in_back_nd.handle(), + /*dx=*/backward_input_data->opaque()); if (is_profiling) { timer->Stop(AsCUDAStream(stream)); if (status == CUDNN_STATUS_SUCCESS) { @@ -3767,12 +3539,6 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( DeviceMemory<T>* backward_filter_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status); - } - cudnnDataType_t cudnn_type = GetCudnnDataType<T>(); // Alpha is the scaling factor for input. float falpha = 1.0; @@ -3785,19 +3551,21 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta) : static_cast<void*>(&fbeta); + auto cudnn = cudnn_->GetHandle(parent_, stream); + // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass. 
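The layout shuffle that the TBD above refers to is what MaybeTransformLayout (now a free function earlier in this diff) performs with cudnnTransformTensor: elements are reordered according to the strides baked into the source and destination descriptors. Reduced to its essentials, with the descriptors assumed to be pre-built (for example an NHWC source and an NCHW destination):

#include <cudnn.h>

// Copies `src` into `dst`, permuting elements between the layouts encoded in
// the two tensor descriptors. alpha/beta follow the usual cuDNN scaling
// convention: dst = alpha * transform(src) + beta * dst.
bool TransformLayout(cudnnHandle_t handle, cudnnTensorDescriptor_t src_desc,
                     const void* src, cudnnTensorDescriptor_t dst_desc,
                     void* dst) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  return cudnnTransformTensor(handle, &alpha, src_desc, src, &beta, dst_desc,
                              dst) == CUDNN_STATUS_SUCCESS;
}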
BatchDescriptor output_descriptor; output_descriptor.CloneFrom(output_descriptor_in); std::unique_ptr<TemporaryDeviceMemory<T>> transform_scratch; - backward_output_data = MaybeTransformLayout( - stream, &output_descriptor, backward_output_data, &transform_scratch); + backward_output_data = + MaybeTransformLayout(stream, cudnn, &output_descriptor, + backward_output_data, &transform_scratch); - ScopedTensorDescriptor out_back_nd{parent_, output_descriptor, cudnn_type}; - ScopedTensorDescriptor input_nd{parent_, input_descriptor, cudnn_type}; - ScopedFilterDescriptor filter{parent_, filter_descriptor, input_descriptor, - cudnn_type}; - ScopedConvolutionDescriptor conv{parent_, convolution_descriptor, - GetConvComputeType<T>()}; + ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type); + ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type); + ScopedFilterDescriptor filter(filter_descriptor, cudnn_type); + ScopedConvolutionDescriptor conv(convolution_descriptor, + GetConvComputeType<T>()); const bool is_profiling = output_profile_result != nullptr; cudnnConvolutionBwdFilterAlgo_t algo; @@ -3809,8 +3577,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( // Lambda that retrieves the algorithm. // specify_limit will occur when we have a scratch allocator and it succeeds // in allocating; otherwise, we'll fall back to the "no workspace" version. - auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED( - dnn_handle_mutex_) { + auto get_algorithm = [&](bool specify_limit) { cudnnConvolutionBwdFilterPreference_t preference = specify_limit ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE; @@ -3824,8 +3591,8 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( } cudnnConvolutionBwdFilterAlgo_t algo_to_use; - cudnnStatus_t status = wrap::cudnnGetConvolutionBackwardFilterAlgorithm( - parent_, ToHandle(dnn_handle_), + cudnnStatus_t status = cudnnGetConvolutionBackwardFilterAlgorithm( + cudnn.handle(), /*srcDesc=*/input_nd.handle(), /*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(), @@ -3843,9 +3610,10 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( if (scratch_allocator != nullptr) { size_t size_in_bytes; - status = wrap::cudnnGetConvolutionBackwardFilterWorkspaceSize( - parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(), - /*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(), + auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( + cudnn.handle(), + /*xDesc=*/input_nd.handle(), + /*dyDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(), /*gradDesc=*/filter.handle(), /*algo=*/algo, /*sizeInBytes=*/&size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; @@ -3878,9 +3646,10 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( conv.set_use_tensor_op_math(algotype.tensor_ops_enabled()); size_t size_in_bytes; - status = wrap::cudnnGetConvolutionBackwardFilterWorkspaceSize( - parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_nd.handle(), - /*diffDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(), + auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( + cudnn.handle(), + /*xDesc=*/input_nd.handle(), + /*dyDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(), /*gradDesc=*/filter.handle(), /*algo=*/algo, /*sizeInBytes=*/&size_in_bytes); if (status != CUDNN_STATUS_SUCCESS) { @@ -3934,11 +3703,13 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( } #if CUDNN_VERSION >= 5000 - status = wrap::cudnnConvolutionBackwardFilter( + auto status = 
cudnnConvolutionBackwardFilter( + cudnn.handle(), #else - status = wrap::cudnnConvolutionBackwardFilter_v3( + auto status = cudnnConvolutionBackwardFilter_v3( + cudnn.handle(), #endif - this, stream, ToHandle(dnn_handle_), /*alpha=*/alpha, + /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(), /*srcData=*/input_data.opaque(), /*diffDesc=*/out_back_nd.handle(), @@ -4033,25 +3804,19 @@ bool CudnnSupport::DoConvolveBackwardBiasImpl( const DeviceMemory<T>& input_data, const dnn::BatchDescriptor& bias_descriptor, DeviceMemory<T>* backward_bias_data) { - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status); - } - cudnnDataType_t cudnn_type = GetCudnnDataType<T>(); - ScopedTensorDescriptor input_nd{parent_, input_descriptor, cudnn_type}; - ScopedTensorDescriptor bias_nd{parent_, bias_descriptor, cudnn_type}; + ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type); + ScopedTensorDescriptor bias_nd(bias_descriptor, cudnn_type); // Alpha is the scaling factor for input. float alpha = 1.0; // Beta is the scaling factor for output. float beta = 0.0; - status = wrap::cudnnConvolutionBackwardBias( - this, stream, ToHandle(dnn_handle_), &alpha, input_nd.handle(), - input_data.opaque(), &beta, bias_nd.handle(), - backward_bias_data->opaque()); + auto cudnn = cudnn_->GetHandle(parent_, stream); + auto status = cudnnConvolutionBackwardBias( + cudnn.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta, + bias_nd.handle(), backward_bias_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "failed to enqueue backward convolution on stream: " << ToString(status); @@ -4227,8 +3992,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream, const DeviceMemory<float>& biases, const dnn::BatchDescriptor& dimensions, DeviceMemory<float>* output_data) { - ScopedTensorDescriptor input_descriptor{parent_, dimensions, - CUDNN_DATA_FLOAT}; + ScopedTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT); BatchDescriptor bias_dimensions; bias_dimensions.set_count(1) @@ -4236,8 +4000,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream, .set_height(1) .set_width(1) .set_layout(dnn::DataLayout::kBatchYXDepth); - ScopedTensorDescriptor bias_descriptor{parent_, bias_dimensions, - CUDNN_DATA_FLOAT}; + ScopedTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT); // cudnnAddTensor after R3 is in-place, so we need to copy input_data to // output_data before doing the addition, unless the input and @@ -4253,23 +4016,18 @@ bool CudnnSupport::DoBiasAdd(Stream* stream, } } - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status); - return false; - } - const float alpha = 1.0f; const float beta = 1.0f; + auto cudnn = cudnn_->GetHandle(parent_, stream); + #if CUDNN_VERSION >= 5000 - status = wrap::cudnnAddTensor( + auto status = cudnnAddTensor( #else - status = wrap::cudnnAddTensor_v3( + auto status = cudnnAddTensor_v3( #endif - this, stream, ToHandle(dnn_handle_), &alpha, bias_descriptor.handle(), - biases.opaque(), &beta, input_descriptor.handle(), output_data->opaque()); + cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(), &beta, + input_descriptor.handle(), output_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << 
"stream " << stream << " could not enqueue bias addition."; @@ -4285,16 +4043,9 @@ bool CudnnSupport::DoActivate(Stream* stream, const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data, uint64 options) { - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status); - return false; - } - #if CUDNN_VERSION >= 5000 - ScopedActivationDescriptor activation_desc{ - parent_, activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max()}; + ScopedActivationDescriptor activation_desc( + activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max()); #else cudnnActivationMode_t mode; switch (activation_mode) { @@ -4324,20 +4075,22 @@ bool CudnnSupport::DoActivate(Stream* stream, } #endif - ScopedTensorDescriptor input_nd{parent_, dimensions, CUDNN_DATA_FLOAT}; + ScopedTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT); // Alpha is the input scaling factor. float alpha = 1.0; // Beta is the output scaling factor. float beta = 0.0; - status = wrap::cudnnActivationForward( - this, stream, ToHandle(dnn_handle_), + + auto cudnn = cudnn_->GetHandle(parent_, stream); + auto status = + cudnnActivationForward(cudnn.handle(), #if CUDNN_VERSION >= 5000 - activation_desc.handle(), + activation_desc.handle(), #else - mode, + mode, #endif - &alpha, input_nd.handle(), input_data.opaque(), &beta, input_nd.handle(), - output_data->opaque()); + &alpha, input_nd.handle(), input_data.opaque(), + &beta, input_nd.handle(), output_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "stream " << stream << " could not enqueue activation: " << ToString(status); @@ -4353,26 +4106,19 @@ bool CudnnSupport::DoPoolForward( const DeviceMemory<double>& input_data, const dnn::BatchDescriptor& output_dimensions, DeviceMemory<double>* output_data) { - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status); - return false; - } - // Alpha is the scaling factor for input. double alpha = 1.0; // Beta is the scaling factor for output. 
double beta = 0.0; - ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_DOUBLE}; - ScopedTensorDescriptor dest_desc{parent_, output_dimensions, - CUDNN_DATA_DOUBLE}; - ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions}; - status = wrap::cudnnPoolingForward( - this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha, - src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(), - output_data->opaque()); + ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE); + ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE); + ScopedPoolingDescriptor pooling_desc(pooling_dimensions); + + auto cudnn = cudnn_->GetHandle(parent_, stream); + auto status = cudnnPoolingForward( + cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), + input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "failed to enqueue forward pooling on stream: " << ToString(status); @@ -4387,26 +4133,19 @@ bool CudnnSupport::DoPoolForward( const DeviceMemory<float>& input_data, const dnn::BatchDescriptor& output_dimensions, DeviceMemory<float>* output_data) { - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status); - return false; - } - // Alpha is the scaling factor for input. float alpha = 1.0; // Beta is the scaling factor for output. float beta = 0.0; - ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_FLOAT}; - ScopedTensorDescriptor dest_desc{parent_, output_dimensions, - CUDNN_DATA_FLOAT}; - ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions}; - status = wrap::cudnnPoolingForward( - this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha, - src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(), - output_data->opaque()); + ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT); + ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT); + ScopedPoolingDescriptor pooling_desc(pooling_dimensions); + + auto cudnn = cudnn_->GetHandle(parent_, stream); + auto status = cudnnPoolingForward( + cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), + input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "failed to enqueue forward pooling on stream: " << ToString(status); @@ -4421,25 +4160,18 @@ bool CudnnSupport::DoPoolForward( const DeviceMemory<Eigen::half>& input_data, const dnn::BatchDescriptor& output_dimensions, DeviceMemory<Eigen::half>* output_data) { - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status); - return false; - } - // Alpha is the scaling factor for input. float alpha = 1.0; // Beta is the scaling factor for output. 
float beta = 0.0; - ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF}; - ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF}; - ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions}; - status = wrap::cudnnPoolingForward( - this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha, - src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(), - output_data->opaque()); + ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF); + ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF); + ScopedPoolingDescriptor pooling_desc(pooling_dimensions); + auto cudnn = cudnn_->GetHandle(parent_, stream); + auto status = cudnnPoolingForward( + cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), + input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "failed to enqueue forward pooling on stream: " << ToString(status); @@ -4456,27 +4188,21 @@ bool CudnnSupport::DoPoolBackward( const DeviceMemory<double>& output_data, const DeviceMemory<double>& input_diff_data, DeviceMemory<double>* output_diff_data) { - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status); - return false; - } - // Alpha is the scaling factor for input. double alpha = 1.0; // Beta is the scaling factor for output. double beta = 0.0; - ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_DOUBLE}; - ScopedTensorDescriptor dest_desc{parent_, output_dimensions, - CUDNN_DATA_DOUBLE}; - ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions}; - status = wrap::cudnnPoolingBackward( - this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha, - dest_desc.handle(), output_data.opaque(), dest_desc.handle(), - input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta, - src_desc.handle(), output_diff_data->opaque()); + ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE); + ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE); + ScopedPoolingDescriptor pooling_desc(pooling_dimensions); + + auto cudnn = cudnn_->GetHandle(parent_, stream); + auto status = cudnnPoolingBackward( + cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), + output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), + src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), + output_diff_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "failed to enqueue backward pooling on stream: " << ToString(status); @@ -4493,27 +4219,21 @@ bool CudnnSupport::DoPoolBackward( const DeviceMemory<float>& output_data, const DeviceMemory<float>& input_diff_data, DeviceMemory<float>* output_diff_data) { - mutex_lock lock{dnn_handle_mutex_}; - auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status); - return false; - } - // Alpha is the scaling factor for input. float alpha = 1.0; // Beta is the scaling factor for output. 
   float beta = 0.0;
-  ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_FLOAT};
-  ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
-                                   CUDNN_DATA_FLOAT};
-  ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
-  status = wrap::cudnnPoolingBackward(
-      this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
-      dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
-      input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
-      src_desc.handle(), output_diff_data->opaque());
+  ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
+  ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
+  ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+
+  auto cudnn = cudnn_->GetHandle(parent_, stream);
+  auto status = cudnnPoolingBackward(
+      cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
+      output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
+      src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
+      output_diff_data->opaque());
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to enqueue backward pooling on stream: "
                << ToString(status);
@@ -4530,26 +4250,21 @@ bool CudnnSupport::DoPoolBackward(
     const DeviceMemory<Eigen::half>& output_data,
     const DeviceMemory<Eigen::half>& input_diff_data,
     DeviceMemory<Eigen::half>* output_diff_data) {
-  mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
-  if (status != CUDNN_STATUS_SUCCESS) {
-    LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
-    return false;
-  }
-
   // Alpha is the scaling factor for input.
   float alpha = 1.0;
   // Beta is the scaling factor for output.
   float beta = 0.0;
-  ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF};
-  ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
-  ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
-  status = wrap::cudnnPoolingBackward(
-      this, stream, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
-      dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
-      input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
-      src_desc.handle(), output_diff_data->opaque());
+  ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
+  ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
+  ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+
+  auto cudnn = cudnn_->GetHandle(parent_, stream);
+  auto status = cudnnPoolingBackward(
+      cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
+      output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
+      src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
+      output_diff_data->opaque());
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to enqueue backward pooling on stream: "
                << ToString(status);
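All of the pooling paths above follow cuDNN's blending convention, output = alpha * Op(input) + beta * output, which is why every overload passes alpha = 1 and beta = 0 to simply overwrite the destination buffer. For readers unfamiliar with the convention, here is a minimal standalone sketch of the same cudnnPoolingForward call outside the StreamExecutor wrappers; every name and size in it is illustrative only and not part of this change, and error checks on the CUDA/cuDNN setup calls are elided for brevity:

#include <cuda_runtime.h>
#include <cudnn.h>

#include <cstdio>

int main() {
  // Hypothetical sizes, chosen only for illustration: one 4x4 single-channel
  // image, max-pooled with a 2x2 window and stride 2 down to 2x2.
  const int n = 1, c = 1, h = 4, w = 4;
  const int oh = 2, ow = 2;

  cudnnHandle_t handle;
  cudnnCreate(&handle);

  cudnnTensorDescriptor_t in_desc, out_desc;
  cudnnCreateTensorDescriptor(&in_desc);
  cudnnCreateTensorDescriptor(&out_desc);
  cudnnSetTensor4dDescriptor(in_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
                             n, c, h, w);
  cudnnSetTensor4dDescriptor(out_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
                             n, c, oh, ow);

  cudnnPoolingDescriptor_t pool_desc;
  cudnnCreatePoolingDescriptor(&pool_desc);
  cudnnSetPooling2dDescriptor(pool_desc, CUDNN_POOLING_MAX, CUDNN_PROPAGATE_NAN,
                              /*windowH=*/2, /*windowW=*/2, /*padH=*/0,
                              /*padW=*/0, /*strideH=*/2, /*strideW=*/2);

  float host_in[16];
  for (int i = 0; i < 16; ++i) host_in[i] = static_cast<float>(i);

  float *dev_in, *dev_out;
  cudaMalloc(&dev_in, sizeof(host_in));
  cudaMalloc(&dev_out, 4 * sizeof(float));
  cudaMemcpy(dev_in, host_in, sizeof(host_in), cudaMemcpyHostToDevice);

  // alpha scales the freshly computed pooling result; beta scales whatever is
  // already in the output buffer. alpha = 1, beta = 0 therefore overwrites.
  const float alpha = 1.0f, beta = 0.0f;
  cudnnStatus_t status = cudnnPoolingForward(
      handle, pool_desc, &alpha, in_desc, dev_in, &beta, out_desc, dev_out);
  if (status != CUDNN_STATUS_SUCCESS) {
    std::printf("pooling failed: %s\n", cudnnGetErrorString(status));
  }

  float host_out[4];
  cudaMemcpy(host_out, dev_out, sizeof(host_out), cudaMemcpyDeviceToHost);
  // Expect the max of each 2x2 window: 5 7 13 15.
  std::printf("%g %g %g %g\n", host_out[0], host_out[1], host_out[2],
              host_out[3]);

  cudaFree(dev_in);
  cudaFree(dev_out);
  cudnnDestroyPoolingDescriptor(pool_desc);
  cudnnDestroyTensorDescriptor(in_desc);
  cudnnDestroyTensorDescriptor(out_desc);
  cudnnDestroy(handle);
  return 0;
}

A nonzero beta would instead accumulate the pooled result into the existing contents of dev_out; none of the StreamExecutor paths above use that feature.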
@@ -4571,7 +4286,7 @@ bool CudnnSupport::DoNormalizeWithDimensions(
     const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
   // Check for unsupported modes.
   if (normalize_descriptor.wrap_around()) {
-    LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
+    LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
     return false;
   }
   if (normalize_descriptor.segment_size()) {
@@ -4579,26 +4294,21 @@ bool CudnnSupport::DoNormalizeWithDimensions(
     return false;
   }
 
-  // Launch the normalization.
-  mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
-  if (status != CUDNN_STATUS_SUCCESS) {
-    LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
-    return false;
-  }
-
-  ScopedTensorDescriptor dims{parent_, dimensions, CUDNN_DATA_FLOAT};
-  ScopedNormalizeDescriptor normalize{parent_, normalize_descriptor};
+  ScopedTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
+  ScopedNormalizeDescriptor normalize(normalize_descriptor);
 
   // Alpha is the scaling factor for input.
   float alpha = 1.0f;
   // Beta is the scaling factor for output.
   float beta = 0.0f;
 
-  status = wrap::cudnnLRNCrossChannelForward(
-      this, stream, ToHandle(dnn_handle_), normalize.handle(),
-      CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dims.handle(), input_data.opaque(),
-      &beta, dims.handle(), output_data->opaque());
+  auto cudnn = cudnn_->GetHandle(parent_, stream);
+
+  // Launch the normalization.
+  auto status = cudnnLRNCrossChannelForward(
+      cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha,
+      dims.handle(), input_data.opaque(), &beta, dims.handle(),
+      output_data->opaque());
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "failed to run cudnnLRNCrossChannelForward";
     return false;
@@ -4614,7 +4324,7 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
     DeviceMemory<float>* raw_variable_gradient) {
   // Check for unsupported modes.
   if (normalize_descriptor.wrap_around()) {
-    LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
+    LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
     return false;
   }
   if (normalize_descriptor.segment_size()) {
@@ -4622,23 +4332,16 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
     return false;
   }
 
-  mutex_lock lock{dnn_handle_mutex_};
-  auto status = wrap::cudnnSetStream(this, stream, ToHandle(dnn_handle_));
-  if (status != CUDNN_STATUS_SUCCESS) {
-    LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
-    return false;
-  }
-
-  ScopedTensorDescriptor dims{parent_, dimensions, CUDNN_DATA_FLOAT};
-  ScopedNormalizeDescriptor normalize{parent_, normalize_descriptor};
+  ScopedTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
+  ScopedNormalizeDescriptor normalize(normalize_descriptor);
 
   float alpha = 1.0f;
   float beta = 0.0f;
 
-  status = wrap::cudnnLRNCrossChannelBackward(
-      this, stream, ToHandle(dnn_handle_), normalize.handle(),
-      CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dims.handle(),
-      normalized_data.opaque(), dims.handle(),
+  auto cudnn = cudnn_->GetHandle(parent_, stream);
+  auto status = cudnnLRNCrossChannelBackward(
+      cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha,
+      dims.handle(), normalized_data.opaque(), dims.handle(),
       normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
       &beta, dims.handle(), raw_variable_gradient->opaque());
   if (status != CUDNN_STATUS_SUCCESS) {
@@ -4754,17 +4457,14 @@ bool CudnnSupport::DeriveOutputBatchDescriptor(
     const FilterDescriptor& filter_descriptor,
     const dnn::ConvolutionDescriptor& convolution_descriptor,
     dnn::BatchDescriptor* output_batch_descriptor) {
-  ScopedTensorDescriptor input_nd{parent_, batch_descriptor, CUDNN_DATA_FLOAT};
-  ScopedFilterDescriptor filter{parent_, filter_descriptor, batch_descriptor,
-                                CUDNN_DATA_FLOAT};
-  ScopedConvolutionDescriptor conv{parent_, convolution_descriptor,
-                                   CUDNN_DATA_FLOAT};
+  ScopedTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
+  ScopedFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
+  ScopedConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
 
   int dn = batch_descriptor.ndims() + 2;
   std::vector<int> dims(dn);  // in BDYX
-  auto status = wrap::cudnnGetConvolutionNdForwardOutputDim(
-      parent_, conv.handle(), input_nd.handle(), filter.handle(), dn,
-      dims.data());
+  auto status = cudnnGetConvolutionNdForwardOutputDim(
+      conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data());
   if (status != CUDNN_STATUS_SUCCESS) {
     LOG(ERROR) << "could not get output tensor for convolution: "
                << ToString(status);
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index dfe2779949..e2de3c62d8 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -19,6 +19,7 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
 #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
 
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 #include "tensorflow/stream_executor/dnn.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/platform/mutex.h"
@@ -42,7 +43,6 @@ extern const PluginId kCuDnnPlugin;
 class CudnnSupport : public dnn::DnnSupport {
  public:
   explicit CudnnSupport(CUDAExecutor* parent);
-  ~CudnnSupport() override;
 
   port::Status Init() override;
   port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;
@@ -624,54 +624,11 @@ class CudnnSupport : public dnn::DnnSupport {
                   dnn::DataType output_type, float scale,
                   DeviceMemoryBase* output_data) override;
 
-  const Stream* GetCurrentDnnStream() const
-      SHARED_LOCKS_REQUIRED(dnn_handle_mutex_) {
-    return current_dnn_stream_;
-  }
-
-  void SetCurrentDnnStream(Stream* stream)
-      EXCLUSIVE_LOCKS_REQUIRED(dnn_handle_mutex_) {
-    current_dnn_stream_ = stream;
-  }
-
-  CUDAExecutor* GetParentExecutor() { return parent_; }
-
-  // Guards the enqueueing of DNN operations via the dnn_handle_ below, and
-  // access to current_dnn_stream_.
-  //
-  // This is a public member because we need to add thread safety annotations in
-  // the cudnn wrapper functions in the cc file, which need to access this
-  // mutex (the annotations require C++ permission checks).
-  mutex dnn_handle_mutex_;
-
  private:
   CUDAExecutor* parent_;  // Parent executor object. Not owned.
 
-  // cudnn library handle. cudnnHandle_t type is not present in this header to
-  // prevent third-party library header inclusions from leaking outside the
-  // single cuda_dnn translation unit.
-  void* dnn_handle_ GUARDED_BY(dnn_handle_mutex_);
-
-  // The current cudnn stream that is set by cudnnSetStream().
-  Stream* current_dnn_stream_ GUARDED_BY(dnn_handle_mutex_);
-
-  // NOTE(keveman): Temporary data layout transformation until cuDNN supports
-  // kBatchYXDepth for backward pass. This function allocates temporary memory,
-  // lays out the source data into the temporary but in the kBatchDepthXY
-  // layout, and returns the temporary memory. The caller is responsible for
-  // deallocating the temporary. Since the allocation is done using Stream's
-  // AllocateTemporaryMemory, a later BlockHostUntilDone could be used for
-  // deallocation.
-  //
-  // transform_scratch is populated with a legitimate temporary allocation iff
-  // the original output data needs to be transformed.
-  template<class T>
-  DeviceMemory<T> MaybeTransformLayout(
-      Stream* stream,
-      dnn::BatchDescriptor* output_descriptor,
-      DeviceMemory<T> backward_output_data,
-      std::unique_ptr<TemporaryDeviceMemory<T>>* transform_scratch)
-      EXCLUSIVE_LOCKS_REQUIRED(dnn_handle_mutex_);
+  // Provides access to the cuDNN handle.
+  std::unique_ptr<class CudnnAccess> cudnn_;
 
   template <class T, class U>
   bool DoBatchNormalizationForwardImpl(
@@ -700,7 +657,7 @@ class CudnnSupport : public dnn::DnnSupport {
 
   template <class T>
   bool DoConvolveImpl(Stream* stream,
-                      const dnn::BatchDescriptor& batch_descriptor,
+                      const dnn::BatchDescriptor& input_descriptor,
                       const DeviceMemory<T>& input_data,
                       const dnn::FilterDescriptor& filter_descriptor,
                       const DeviceMemory<T>& filter_data,
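The header now owns the cuDNN handle through the single std::unique_ptr<class CudnnAccess> cudnn_ member, replacing the raw void* dnn_handle_, the public dnn_handle_mutex_, and the per-call cudnnSetStream boilerplate removed above. CudnnAccess itself is defined in the .cc file and is not shown in this diff; the sketch below is only a plausible minimal shape for the idea, with simplified types (the real GetHandle also takes the parent executor so it can activate the CUDA context, and it checks errors):

#include <cuda_runtime.h>
#include <cudnn.h>

#include <mutex>
#include <utility>

// Illustrative only: the real CudnnAccess/CudnnHandle use StreamExecutor
// types and live in cuda_dnn.cc.
class CudnnHandle {
 public:
  CudnnHandle(std::unique_lock<std::mutex> lock, cudnnHandle_t handle)
      : lock_(std::move(lock)), handle_(handle) {}

  // The raw handle is only valid while this accessor (and thus the lock)
  // is alive.
  cudnnHandle_t handle() const { return handle_; }

 private:
  std::unique_lock<std::mutex> lock_;
  cudnnHandle_t handle_;
};

class CudnnAccess {
 public:
  explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}

  // Serializes cuDNN calls and binds the handle to the caller's stream once,
  // instead of every method doing its own lock + cudnnSetStream dance.
  CudnnHandle GetHandle(cudaStream_t stream) {
    std::unique_lock<std::mutex> lock(mutex_);
    cudnnSetStream(handle_, stream);
    return CudnnHandle(std::move(lock), handle_);
  }

 private:
  std::mutex mutex_;
  cudnnHandle_t handle_;
};

This is what lets each DoPoolForward/DoPoolBackward overload in the .cc changes shrink to auto cudnn = cudnn_->GetHandle(parent_, stream); followed by a plain cuDNN call: the lock is held exactly as long as the returned accessor lives, and the stream is bound in one place rather than in every method.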
diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc
index 15523028c7..eeb1fab40c 100644
--- a/tensorflow/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/tools/benchmark/benchmark_model.cc
@@ -262,6 +262,10 @@ Status InitializeSession(int num_threads, const string& graph,
   tensorflow::GraphDef tensorflow_graph;
   Status s = ReadBinaryProto(Env::Default(), graph, graph_def->get());
   if (!s.ok()) {
+    s = ReadTextProto(Env::Default(), graph, graph_def->get());
+  }
+
+  if (!s.ok()) {
     LOG(ERROR) << "Could not create TensorFlow Graph: " << s;
     return s;
   }
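With this change the benchmark tool accepts either wire format for the input graph: it first tries the binary parse and, only if that fails, re-reads the same path as a text proto. The same pattern can be lifted into a small helper; the function name here is hypothetical and not part of the tool:

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

// Hypothetical helper: try the binary wire format first, then fall back to
// the human-readable text format, so both graphdef.pb and graphdef.pb.txt
// files load through a single call site.
Status ReadGraphDefAnyFormat(const string& path, GraphDef* graph_def) {
  Status s = ReadBinaryProto(Env::Default(), path, graph_def);
  if (!s.ok()) {
    // When both parses fail, the status that propagates (and that the caller
    // logs) is the one from this second, text-format attempt.
    s = ReadTextProto(Env::Default(), path, graph_def);
  }
  return s;
}

}  // namespace tensorflow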
diff --git a/tensorflow/tools/benchmark/benchmark_model_test.cc b/tensorflow/tools/benchmark/benchmark_model_test.cc
index 16ab2ff66e..6813045d63 100644
--- a/tensorflow/tools/benchmark/benchmark_model_test.cc
+++ b/tensorflow/tools/benchmark/benchmark_model_test.cc
@@ -26,30 +26,36 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
-TEST(BenchmarkModelTest, InitializeAndRun) {
-  const string dir = testing::TmpDir();
-  const string filename_pb = io::JoinPath(dir, "graphdef.pb");
-
+void CreateTestGraph(const ::tensorflow::Scope& root,
+                     benchmark_model::InputLayerInfo* input,
+                     string* output_name, GraphDef* graph_def) {
   // Create a simple graph and write it to filename_pb.
   const int input_width = 400;
   const int input_height = 10;
-  benchmark_model::InputLayerInfo input;
-  input.shape = TensorShape({input_width, input_height});
-  input.data_type = DT_FLOAT;
+  input->shape = TensorShape({input_width, input_height});
+  input->data_type = DT_FLOAT;
   const TensorShape constant_shape({input_height, input_width});
 
   Tensor constant_tensor(DT_FLOAT, constant_shape);
   test::FillFn<float>(&constant_tensor, [](int) -> float { return 3.0; });
 
-  auto root = Scope::NewRootScope().ExitOnError();
   auto placeholder =
-      ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input.shape));
-  input.name = placeholder.node()->name();
+      ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input->shape));
+  input->name = placeholder.node()->name();
 
   auto m = ops::MatMul(root, placeholder, constant_tensor);
-  const string output_name = m.node()->name();
+  *output_name = m.node()->name();
+  TF_ASSERT_OK(root.ToGraphDef(graph_def));
+}
+
+TEST(BenchmarkModelTest, InitializeAndRun) {
+  const string dir = testing::TmpDir();
+  const string filename_pb = io::JoinPath(dir, "graphdef.pb");
+  auto root = Scope::NewRootScope().ExitOnError();
 
+  benchmark_model::InputLayerInfo input;
+  string output_name;
   GraphDef graph_def;
-  TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+  CreateTestGraph(root, &input, &output_name, &graph_def);
   string graph_def_serialized;
   graph_def.SerializeToString(&graph_def_serialized);
   TF_ASSERT_OK(
@@ -69,5 +75,30 @@ TEST(BenchmarkModelTest, InitializeAndRun) {
   ASSERT_EQ(num_runs, 10);
 }
 
+TEST(BenchmarkModelTest, TextProto) {
+  const string dir = testing::TmpDir();
+  const string filename_txt = io::JoinPath(dir, "graphdef.pb.txt");
+  auto root = Scope::NewRootScope().ExitOnError();
+
+  benchmark_model::InputLayerInfo input;
+  string output_name;
+  GraphDef graph_def;
+  CreateTestGraph(root, &input, &output_name, &graph_def);
+  TF_ASSERT_OK(WriteTextProto(Env::Default(), filename_txt, graph_def));
+
+  std::unique_ptr<Session> session;
+  std::unique_ptr<GraphDef> loaded_graph_def;
+  TF_ASSERT_OK(benchmark_model::InitializeSession(1, filename_txt, &session,
+                                                  &loaded_graph_def));
+  std::unique_ptr<StatSummarizer> stats;
+  stats.reset(new tensorflow::StatSummarizer(*(loaded_graph_def.get())));
+  int64 time;
+  int64 num_runs = 0;
+  TF_ASSERT_OK(benchmark_model::TimeMultipleRuns(
+      0.0, 10, 0.0, {input}, {output_name}, {}, session.get(), stats.get(),
+      &time, &num_runs));
+  ASSERT_EQ(num_runs, 10);
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py
index 9ddb219048..00bfcfd49b 100755
--- a/tensorflow/tools/ci_build/update_version.py
+++ b/tensorflow/tools/ci_build/update_version.py
@@ -250,7 +250,7 @@ def update_md_files(old_version, new_version):
 
   # Update any links to colab notebooks.
   def colab_url(version):
-    version_string = "%d.%d.%d" % (version.major, version.minor, version.patch)
+    version_string = "%s.%s.%s" % (version.major, version.minor, version.patch)
     prefix = "https://colab.research.google.com/github/tensorflow/models/blob/r"
     return prefix + version_string + "/"
 
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 8b6ad0a138..01d424f20b 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -228,6 +228,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
       sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef",
       strip_prefix = "libpng-1.6.34",
       build_file = clean_dep("//third_party:png.BUILD"),
+      patch_file = clean_dep("//third_party:png_fix_rpi.patch"),
   )
 
   tf_http_archive(
@@ -452,11 +453,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   tf_http_archive(
       name = "llvm",
       urls = [
-          "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/b3f6a6a61625296bb532a65c0bf51b91b05b3361.tar.gz",
-          "https://github.com/llvm-mirror/llvm/archive/b3f6a6a61625296bb532a65c0bf51b91b05b3361.tar.gz",
+          "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/7b8a8728fbd27086efbf3c57cf2bb35a557108c9.tar.gz",
+          "https://github.com/llvm-mirror/llvm/archive/7b8a8728fbd27086efbf3c57cf2bb35a557108c9.tar.gz",
       ],
-      sha256 = "93895b289a78a47a1e75652e12a1b9a6c119f086a509b00e0084cf2bb944b709",
-      strip_prefix = "llvm-b3f6a6a61625296bb532a65c0bf51b91b05b3361",
+      sha256 = "c620859c3ae5818f316de4837f340b3bba1646f8add0a28e6d4da34ce47e3969",
+      strip_prefix = "llvm-7b8a8728fbd27086efbf3c57cf2bb35a557108c9",
       build_file = clean_dep("//third_party/llvm:llvm.BUILD"),
   )
 
@@ -744,6 +745,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
       build_file = clean_dep("//third_party:tflite_smartreply.BUILD"),
   )
 
+  tf_http_archive(
+      name = "tflite_ovic_testdata",
+      sha256 = "a9a705d8d519220178e2e65d383fdb21da37fdb31d1e909b0a1acdac46479e9c",
+      urls = [
+          "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
+          "https://storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
+      ],
+      build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),
+      strip_prefix = "ovic",
+  )
+
   ##############################################################################
   # BIND DEFINITIONS
   #
diff --git a/third_party/clang_toolchain/download_clang.bzl b/third_party/clang_toolchain/download_clang.bzl
index 54d383d7d7..cfd8bfe98d 100644
--- a/third_party/clang_toolchain/download_clang.bzl
+++ b/third_party/clang_toolchain/download_clang.bzl
@@ -35,18 +35,18 @@ def download_clang(repo_ctx, out_folder):
 
   # Latest CLANG_REVISION and CLANG_SUB_REVISION of the Chromium's release
   # can be found in https://chromium.googlesource.com/chromium/src/tools/clang/+/master/scripts/update.py
 
-  CLANG_REVISION = '321529'
+  CLANG_REVISION = '330570'
   CLANG_SUB_REVISION = 2
 
   package_version = '%s-%s' % (CLANG_REVISION, CLANG_SUB_REVISION)
 
   checksums = {
       'Linux_x64':
-          '76d4eb1ad011e3127c4a9de9b9f5d4ac624b5a9395c4d7395c9e0a487b13daf6',
+          '2108e172e05d4904c3c46125a33ab4a1175b36ec2a2226619a243e1d8f397e97',
       'Mac':
-          '4b2a7a65ac1ee892b318c723eec8771f514bb306f346aa8216bb0006f19d87b7',
+          '481b5c6909f0ea250216061bd45e9c982b4befff65cbfca2ee1090c21a109eac',
       'Win':
-          'eba51bb8f84af41a85903113666bd21c22709010c39c4cb19dc20cf1ed14581b',
+          '8f04a3ac99d463d4179eb2f68a13575408c3dddc62887a1e441c77123e35e301',
   }
 
   platform_folder = _get_platform_folder(repo_ctx.os.name)
diff --git a/third_party/png_fix_rpi.patch b/third_party/png_fix_rpi.patch
new file mode 100644
index 0000000000..80da7b3c06
--- /dev/null
+++ b/third_party/png_fix_rpi.patch
@@ -0,0 +1,16 @@
+diff -r -u /tmp/libpng-1.6.34/scripts/pnglibconf.h.prebuilt ./scripts/pnglibconf.h.prebuilt
+--- /tmp/libpng-1.6.34/scripts/pnglibconf.h.prebuilt	2017-09-29 01:42:33.000000000 -0700
++++ ./scripts/pnglibconf.h.prebuilt	2018-05-01 09:51:24.719318242 -0700
+@@ -20,6 +20,12 @@
+ #define PNG_ALIGNED_MEMORY_SUPPORTED
+ /*#undef PNG_ARM_NEON_API_SUPPORTED*/
+ /*#undef PNG_ARM_NEON_CHECK_SUPPORTED*/
++
++/* Workaround not having a great build file by forcing
++ * png filter optimization to be disabled on arm */
++#define PNG_ARM_NEON_OPT 0
++
++
+ /*#undef PNG_POWERPC_VSX_API_SUPPORTED*/
+ /*#undef PNG_POWERPC_VSX_CHECK_SUPPORTED*/
+ #define PNG_BENIGN_ERRORS_SUPPORTED
diff --git a/third_party/tflite_ovic_testdata.BUILD b/third_party/tflite_ovic_testdata.BUILD
new file mode 100644
index 0000000000..de47ed61f9
--- /dev/null
+++ b/third_party/tflite_ovic_testdata.BUILD
@@ -0,0 +1,12 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(
+    glob(
+        ["**/*"],
+        exclude = [
+            "BUILD",
+        ],
+    ),
+)