diff options
author | avijit-nervana <avijit.chakraborty@intel.com> | 2018-09-05 09:40:03 -0700 |
---|---|---|
committer | avijit-nervana <avijit.chakraborty@intel.com> | 2018-09-05 09:40:03 -0700 |
commit | d972850a44ae624ad957b0dcc9d740e18f0cc10c (patch) | |
tree | bb80dd62961c20130c386d4cf62ca5cb9150240d | |
parent | a65c6c17d0705fe11be6f33f63a677106bf26ffb (diff) | |
parent | 47860208eee575119b0dd1b6168dc24cf51caf64 (diff) |
Merge branch 'master' into avijit/add-cpu-backend
252 files changed, 10421 insertions, 6541 deletions
diff --git a/CODEOWNERS b/CODEOWNERS index 1725a5c471..78f80c8d71 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -60,3 +60,5 @@ /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj /tensorflow/contrib/training/ @joel-shor @ebrevdo /tensorflow/contrib/util/ @sherrym + +/third_party/systemlibs/ @perfinion diff --git a/tensorflow/BUILD b/tensorflow/BUILD index b5e0a4e98b..661cba5ff0 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -433,6 +433,7 @@ package_group( "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", + "//tensorflow_estimator/...", "//tensorflow_fold/llgtm/...", "//third_party/py/tensor2tensor/...", ], diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 2c3a877edf..109b3b37aa 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -117,6 +117,7 @@ tf_cuda_library( deps = [ ":c_api", ":c_api_internal", + "//tensorflow/c/eager:c_api", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/contrib/tpu:all_ops", "//tensorflow/core:core_cpu", diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 6617c5a572..09d482d6df 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -20,6 +20,7 @@ limitations under the License. #include <stdint.h> #include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" // -------------------------------------------------------------------------- // Experimental C API for TensorFlow. 
@@ -131,6 +132,9 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, TF_Tensor* tensor, TF_Status* status); +TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession( + const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 1ccae3f138..77e3878a94 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -273,7 +273,20 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { new tensorflow::IntraProcessRendezvous(device_mgr.get()); return new TFE_Context(opts->session_options.options, opts->policy, - opts->async, std::move(device_mgr), r); + opts->async, device_mgr.release(), + /*device_mgr_owned*/ true, r); +} + +TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, + TF_Session* sess, TF_Status* status) { + const tensorflow::DeviceMgr* device_mgr = nullptr; + status->status = sess->session->LocalDeviceManager(&device_mgr); + if (!status->status.ok()) return nullptr; + tensorflow::Rendezvous* r = + new tensorflow::IntraProcessRendezvous(device_mgr); + return new TFE_Context(opts->session_options.options, opts->policy, + opts->async, device_mgr, /*device_mgr_owned*/ false, + r); } void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index a5c0681e2e..104d52430c 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -62,15 +62,14 @@ struct TFE_ContextOptions { }; struct TFE_Context { - explicit TFE_Context(const tensorflow::SessionOptions& opts, - TFE_ContextDevicePlacementPolicy default_policy, - bool async, - std::unique_ptr<tensorflow::DeviceMgr> device_mgr, - tensorflow::Rendezvous* rendezvous) + TFE_Context(const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, bool 
async, + const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned, + tensorflow::Rendezvous* rendezvous) : context(opts, static_cast<tensorflow::ContextDevicePlacementPolicy>( default_policy), - async, std::move(device_mgr), rendezvous) {} + async, device_mgr, device_mgr_owned, rendezvous) {} tensorflow::EagerContext context; }; diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index b5667ca0d3..e2affee51f 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -40,26 +40,11 @@ using xla::StatusOr; namespace tensorflow { namespace functionalize_cond { -string DebugString(const CondStateMap::CondNode& node) { - return node.ToString(); -} - // TODO(jpienaar): Move to OutputTensor. string DebugString(const OutputTensor& tensor) { return strings::StrCat(tensor.node->name(), ":", tensor.index); } -string DebugString(CondStateMap::CondId cond_state) { - if (cond_state == nullptr || cond_state->empty()) return "[]"; - return strings::StrCat( - "[", - absl::StrJoin(*cond_state, ", ", - [](string* output, const CondStateMap::CondNode& node) { - strings::StrAppend(output, node.ToString()); - }), - "]"); -} - string Branch_Name(BranchType b) { switch (b) { case BranchType::kElseBranch: @@ -73,6 +58,24 @@ string Branch_Name(BranchType b) { } } +string DebugString(StateMap::CondId cond_state) { + if (cond_state == nullptr || cond_state->empty()) return "{}"; + using value_type = StateMap::CondState::value_type; + return strings::StrCat( + "{", + absl::StrJoin(*cond_state, ", ", + [](string* output, const value_type& pred_branch) { + const OutputTensor& pred = pred_branch.first; + const BranchType& branch = pred_branch.second; + if (branch == BranchType::kNeither) + strings::StrAppend(output, "d"); + else + strings::StrAppend(output, "s(", DebugString(pred), ",", + Branch_Name(branch), ")"); + }), + "}"); +} + // Returns the predicate of a switch. 
Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { const Edge* pred_edge; @@ -86,64 +89,65 @@ Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { return Status::OK(); } -CondStateMap::CondNode::CondNode(Type type, Node* switch_node, - BranchType branch) - : type(type), branch(branch) { - if (type == Type::kSwitch) { - TF_CHECK_OK(GetSwitchPredicate(*switch_node, &predicate)); - } -} - -string CondStateMap::CondNode::ToString() const { - switch (type) { - case Type::kSwitch: - return strings::StrCat("s(", DebugString(predicate), ",", - Branch_Name(branch), ")"); - case Type::kMerge: - return "m"; - case Type::kDead: - return "d"; - } +Status GetSwitchValue(const Node& switch_node, OutputTensor* val) { + const Edge* val_edge; + TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge)); + *val = OutputTensor(val_edge->src(), val_edge->src_output()); + return Status::OK(); } -bool CondStateMap::CondNode::operator==(const CondNode& other) const { - if (type != Type::kSwitch) return type == other.type; - return type == other.type && predicate == other.predicate && - branch == other.branch; +bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs, + const OutputTensor& rhs) const { + return (lhs.node->id() < rhs.node->id()) || + (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index); } -bool CondStateMap::CondNode::operator!=(const CondNode& other) const { - return !(*this == other); -} +struct CondStateLess { + bool operator()(const StateMap::CondState::value_type& lhs, + const StateMap::CondState::value_type& rhs) const { + if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first)) + return true; + if (lhs.first.node->id() == rhs.first.node->id() && + lhs.first.index == rhs.first.index) + return lhs.second < rhs.second; + return false; + } +}; -CondStateMap::CondStateMap(Graph* graph) { +StateMap::StateMap(Graph* graph) { node_to_condid_map_.resize(graph->num_node_ids()); + 
node_to_ancestorid_map_.resize(graph->num_node_ids()); // Initialize the dead state (empty state is designated with a nullptr). - dead_id_ = GetUniqueId({CondNode(CondStateMap::CondNode::Type::kDead)}); + dead_id_ = GetCondId( + {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)}); } -bool CondStateMap::IsDead(CondStateMap::CondId id) const { - return id == dead_id_; -} +bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; } -bool CondStateMap::IsEmpty(CondStateMap::CondId id) const { - return id == nullptr; -} +bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; } -size_t CondStateMap::CondHash::operator()( - const CondStateMap::CondNode& item) const { - return Hash64Combine(Hash64Combine(OutputTensor::Hash()(item.predicate), - hash<BranchType>()(item.branch)), - hash<CondStateMap::CondNode::Type>()(item.type)); +size_t StateMap::Hash::operator()(const StateMap::CondState& map) const { + if (map.empty()) return 0; + // Compute hash of the front element. + auto it = map.begin(); + size_t h = Hash64Combine(OutputTensor::Hash()(it->first), + hash<BranchType>()(it->second)); + for (++it; it != map.end(); ++it) { + // Combine the has with the different elements in the map. + h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first), + hash<BranchType>()(it->second))); + } + return h; } -size_t CondStateMap::CondHash::operator()( - const CondStateMap::CondState& vec) const { - if (vec.empty()) return 0; - size_t h = (*this)(vec.front()); - auto it = vec.begin(); - for (++it; it != vec.end(); ++it) { - h = Hash64Combine(h, (*this)(*it)); +size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const { + if (map.empty()) return 0; + // Compute hash of the front element. + auto it = map.begin(); + size_t h = hash<Node*>()(*it); + for (++it; it != map.end(); ++it) { + // Combine the has with the different elements in the map. 
+ h = Hash64Combine(h, hash<Node*>()(*it)); } return h; } @@ -176,49 +180,71 @@ string DebugString(const CondArgNodes& nodes) { "]"); } -CondStateMap::CondId CondStateMap::LookupId(const Node* node) const { +StateMap::CondId StateMap::LookupCondId(const Node* node) const { if (node->id() < node_to_condid_map_.size()) return node_to_condid_map_[node->id()]; - return added_node_mapping_.at(node->id()); + return added_node_condid_mapping_.at(node->id()); } -CondStateMap::CondId CondStateMap::GetUniqueId( - const CondStateMap::CondState& state) { +StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) { if (state.empty()) return nullptr; return &*condstate_set_.insert(state).first; } -const CondStateMap::CondState& CondStateMap::LookupState( - const Node* node) const { - return *LookupId(node); -} - -void CondStateMap::ResetId(const Node* node, CondStateMap::CondId id) { +void StateMap::ResetCondId(const Node* node, StateMap::CondId id) { if (node->id() < node_to_condid_map_.size()) node_to_condid_map_[node->id()] = id; else - added_node_mapping_[node->id()] = id; + added_node_condid_mapping_[node->id()] = id; +} + +StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const { + if (node->id() < node_to_ancestorid_map_.size()) + return node_to_ancestorid_map_[node->id()]; + return added_node_ancestorid_mapping_.at(node->id()); } -void CondStateMap::MarkDead(const Node* node) { ResetId(node, dead_id_); } +StateMap::AncestorId StateMap::GetAncestorId( + const StateMap::AncestorState& state) { + if (state.empty()) return nullptr; + return &*ancestorstate_set_.insert(state).first; +} -string CondStateMap::CondStateToString(const Node* node) const { - return CondStateToString(LookupId(node)); +void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) { + if (node->id() < node_to_ancestorid_map_.size()) + node_to_ancestorid_map_[node->id()] = id; + else + added_node_ancestorid_mapping_[node->id()] = id; } -string 
CondStateMap::CondStateToString(CondStateMap::CondId id) const { +const StateMap::CondState& StateMap::LookupState(const Node* node) const { + return *LookupCondId(node); +} + +void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); } + +string StateMap::CondStateToString(const Node* node) const { + return CondStateToString(LookupCondId(node)); +} + +string StateMap::CondStateToString(StateMap::CondId id) const { return DebugString(id); } +string StateMap::AncestorStateToString(const Node* node) const { + if (auto id = LookupAncestorId(node)) return NodesToString(*id); + return "{}"; +} + FunctionalizeCond::FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) - : cond_state_map_(graph), library_(library), graph_(graph) {} + : state_map_(graph), library_(library), graph_(graph) {} // Class representing the merge/switch nodes that will become a conditional. class Conditional { public: Conditional(OutputTensor predicate, FunctionalizeCond* parent, - CondStateMap* cond_state_map); + StateMap* cond_state_map); // Adds merge node that is part of this conditional. Status AddMerge(Node* m); @@ -247,6 +273,10 @@ class Conditional { // Adds switch node that is part of this conditional. Status AddSwitch(Node* s); + // Adds a switch node along the edge and rewire the edge to go via the switch. + Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, + Graph* graph); + // Internal name of conditional. The name is based on the first merge node // added. string name() const; @@ -255,7 +285,7 @@ class Conditional { FunctionalizeCond* parent_; // Mapping between nodes and their cond state. - CondStateMap* cond_state_map_; + StateMap* state_map_; // The predicate of the conditional. 
OutputTensor predicate_; @@ -292,8 +322,8 @@ class Conditional { }; Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent, - CondStateMap* cond_state_map) - : parent_(parent), cond_state_map_(cond_state_map), predicate_(predicate) {} + StateMap* cond_state_map) + : parent_(parent), state_map_(cond_state_map), predicate_(predicate) {} Status Conditional::AddMerge(Node* m) { merges_.insert(m); @@ -397,6 +427,35 @@ Status Conditional::BuildArgumentNodes() { return Status::OK(); } +Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, + Graph* graph) { + // Previously we had edge: + // src:src_output ---- edge ----> dst:dst_input + // post this we have (in graph) + // src:src_output --> switch<pred> --- new_edge --> dst:dst_input + + // TODO(jpienaar): One could keep a map caching the extra switch nodes added + // to avoid adding another switch to feed a value for which a switch was + // already added. + Node* switch_node; + Node* src = edge->src(); + int src_output = edge->src_output(); + TF_RETURN_IF_ERROR( + NodeBuilder(graph->NewName(strings::StrCat(src->name(), "_added_switch")), + "Switch") + .Input(src, src_output) + .Input(const_cast<Node*>(predicate_.node), predicate_.index) + .Finalize(graph, &switch_node)); + state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src)); + state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src)); + + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + graph->AddEdge(switch_node, static_cast<int>(branch), dst, dst_input); + return AddSwitch(switch_node); +} + Status Conditional::ExtractBodies(Graph* graph) { VLOG(2) << "Extracting bodies for " << name(); for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) { @@ -405,16 +464,16 @@ Status Conditional::ExtractBodies(Graph* graph) { } auto find_branch = [&](const Edge* e) { - const auto& id = cond_state_map_->LookupId(e->src()); + const auto& id = 
state_map_->LookupCondId(e->src()); return IsSwitch(e->src()) ? BranchType(e->src_output()) - : cond_state_map_->FindBranchOf(id, predicate_); + : state_map_->FindBranchOf(id, predicate_); }; std::array<std::vector<Node*>, 2> stacks; VLOG(5) << "Merges: " << NodesToString(merges_); for (Node* m : merges_) { VLOG(5) << "For merge: " << m->DebugString() << " " - << cond_state_map_->CondStateToString(m); + << state_map_->CondStateToString(m); for (auto e : m->in_edges()) { if (e->IsControlEdge()) continue; BranchType branch = find_branch(e); @@ -422,7 +481,8 @@ Status Conditional::ExtractBodies(Graph* graph) { branch == BranchType::kElseBranch) << "Error: " << e->src()->name() << " is not on either then or else branch (" << Branch_Name(branch) - << ")."; + << ") for predicate " << DebugString(predicate_) << " [" + << DebugString(state_map_->LookupCondId(e->src())) << "]."; Node* src = e->src(); if (IsSwitch(src)) { // Switch node outputs and dependencies are handled separately. @@ -456,8 +516,8 @@ Status Conditional::ExtractBodies(Graph* graph) { if (IsMerge(dst)) continue; Node* src = e->src(); - auto dst_id = cond_state_map_->LookupId(dst); - auto src_id = cond_state_map_->LookupId(src); + auto dst_id = state_map_->LookupCondId(dst); + auto src_id = state_map_->LookupCondId(src); if (dst_id != src_id) { if (e->IsControlEdge()) { external_control_outputs_.push_back(e->src()); @@ -480,8 +540,11 @@ Status Conditional::ExtractBodies(Graph* graph) { } } - // Copying incomming edges to dst node. - for (const Edge* e : n->in_edges()) { + // Copying incomming edges to dst node. Iterate over a copy of the edges + // as they could be mutated during iteration. + std::vector<const Edge*> in_edges(n->in_edges().begin(), + n->in_edges().end()); + for (const Edge* e : in_edges) { Node* src = e->src(); // Skip src/dst node. if (!src->IsOp()) continue; @@ -494,8 +557,8 @@ Status Conditional::ExtractBodies(Graph* graph) { } // Verify input is from the same context. 
- auto src_id = cond_state_map_->LookupId(src); - auto dst_id = cond_state_map_->LookupId(dst); + auto src_id = state_map_->LookupCondId(src); + auto dst_id = state_map_->LookupCondId(dst); if (IsMerge(dst) || src_id == dst_id) { // TODO(jpienaar): The merge case can be more strict. if (node_map.at(src->id()) == nullptr) { @@ -506,18 +569,25 @@ Status Conditional::ExtractBodies(Graph* graph) { external_control_inputs_.push_back(src); } else { // This shouldn't happen, this means we have an external data input - // not entering via a switch node. Work around this for constant - // nodes as some constant nodes are inserted without the required - // control context dominance. + // not entering via a switch node. Work around this by for + // * constant nodes copy them; + // * non-constant nodes, insert a switch along the edge; if (IsConstant(src)) { node_map.at(src->id()) = output->CopyNode(src); } else { - return errors::InvalidArgument( - "Graph contains node ", FormatNodeForError(*src), - " that feeds into node ", FormatNodeForError(*dst), - " but these nodes are in different control contexts (", - DebugString(src_id), " vs ", DebugString(dst_id), - " (detected during in edge testing)"); + StateMap::CondState state = *dst_id; + state.erase(predicate_); + if (state_map_->GetCondId(state) == src_id) { + TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph)); + continue; + } else { + return errors::InvalidArgument( + "Graph contains node ", FormatNodeForError(*src), + " that feeds into node ", FormatNodeForError(*dst), + " but these nodes are in different control contexts (", + DebugString(src_id), " vs ", DebugString(dst_id), + " (detected during in edge testing)"); + } } } @@ -639,7 +709,8 @@ Status Conditional::BuildIfNode(Graph* graph, VLOG(3) << "Build If node"; NodeDef if_def; TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); - TF_ASSIGN_OR_RETURN(if_node_, parent_->AddIfNode(if_def, *merges_.begin())); + TF_ASSIGN_OR_RETURN(if_node_, + 
parent_->AddIfNode(if_def, *merges_.begin(), predicate_)); return Status::OK(); } @@ -699,7 +770,8 @@ Status Conditional::AddOutputEdges(Graph* graph) { Status Conditional::BuildAndReplace(Graph* graph, FunctionLibraryDefinition* library) { - VLOG(1) << "Build If and replace merge nodes " << name(); + VLOG(1) << "Build If and replace merge nodes " + << NodesToString(this->merges_); if (replaced_) return Status::OK(); TF_RETURN_IF_ERROR(ExtractBodies(graph)); @@ -719,7 +791,7 @@ Status Conditional::BuildAndReplace(Graph* graph, TF_RETURN_IF_ERROR(AddInputEdges(graph)); TF_RETURN_IF_ERROR(AddOutputEdges(graph)); TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_)); - for (Node* m : merges_) cond_state_map_->MarkDead(m); + for (Node* m : merges_) state_map_->MarkDead(m); // Check that the if_node doesn't feed into itself. TF_RETURN_WITH_CONTEXT_IF_ERROR( @@ -735,55 +807,41 @@ string Conditional::name() const { return strings::StrCat((*merges_.begin())->name(), "_if"); } -bool CondStateMap::ScopeIn(CondStateMap::CondId id, - CondStateMap::CondId* scope) { - if (id == nullptr) { - *scope = nullptr; - return true; - } - CondState state; - for (const CondNode& node : *id) { - if (node.type == CondNode::Type::kSwitch) { - state.push_back(node); - } - if (node.type == CondNode::Type::kMerge) { - if (state.empty()) { - return false; - } - DCHECK(state.back().type == CondNode::Type::kSwitch && - state.back().branch == BranchType::kBoth); - state.pop_back(); - } - } - *scope = GetUniqueId(state); - return true; -} - Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, int port) { Node* id; TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity") .Input(if_node, port) .Finalize(graph_, &id)); - cond_state_map_.ResetId(id, cond_state_map_.LookupId(if_node)); + state_map_.ResetCondId(id, state_map_.LookupCondId(if_node)); + state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node)); return Status::OK(); } StatusOr<Node*> 
FunctionalizeCond::AddIfNode(const NodeDef& def, - const Node* replacee) { + const Node* replacee, + const OutputTensor& predicate) { Status status; Node* ret = graph_->AddNode(def, &status); TF_RETURN_IF_ERROR(status); - CondStateMap::CondState state = cond_state_map_.LookupState(replacee); - state.pop_back(); VLOG(1) << "Adding If for " << replacee->name(); - cond_state_map_.ResetId(ret, cond_state_map_.GetUniqueId(state)); + StateMap::CondId id = state_map_.LookupCondId(replacee); + if (id) { + StateMap::CondState state = *id; + state.erase(predicate); + state_map_.ResetCondId(ret, state_map_.GetCondId(state)); + } else { + state_map_.ResetCondId(ret, nullptr); + } + + state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee)); + return ret; } Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { VLOG(2) << "Propagating update state for " << replacee->name() << " " - << cond_state_map_.CondStateToString(replacee); + << state_map_.CondStateToString(replacee); // Redo topological sort as the order could have changed. // TODO(jpienaar): The original topological order could also be updated // dynamically if needed. @@ -801,10 +859,10 @@ Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { if (changed.find(*it) != changed.end()) { // Update the node state. 
Node* n = *it; - CondStateMap::CondId old_state = cond_state_map_.LookupId(n); - cond_state_map_.ResetId(n, nullptr); + StateMap::CondId old_state = state_map_.LookupCondId(n); + state_map_.ResetCondId(n, nullptr); TF_RETURN_IF_ERROR(DetermineCondState(n)); - if (cond_state_map_.LookupId(n) != old_state) { + if (state_map_.LookupCondId(n) != old_state) { for (auto out : n->out_nodes()) if (out->IsOp()) changed.insert(out); } @@ -825,127 +883,44 @@ BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) { return BranchType::kNeither; } -CondStateMap::ContainsResult CondStateMap::LhsHoldsWhereverRhsHolds( - CondStateMap::CondId lhs, CondStateMap::CondId rhs) { - CondId lhs_scope; - CondId rhs_scope; - bool could_determine_scope = ScopeIn(lhs, &lhs_scope); - could_determine_scope = could_determine_scope && ScopeIn(rhs, &rhs_scope); - if (!could_determine_scope) return kIncomparable; - - // Returns whether a contains b. - auto contains = [&](CondId a, CondId b) { - // Handle empty states. 
- if (a == nullptr && b != nullptr) return true; - if (a == nullptr && b == nullptr) return true; - if (a != nullptr && b == nullptr) return false; - - if (a->size() > b->size()) return false; - auto a_it = a->begin(); - auto b_it = b->begin(); - while (a_it != a->end()) { - if (*a_it != *b_it) { - if (!(a_it->predicate == b_it->predicate)) return false; - BranchType mb = MeetBranch(a_it->branch, b_it->branch); - if (mb != b_it->branch) return false; - } - ++a_it; - ++b_it; - } - return true; - }; - - bool lhs_contains_rhs = contains(lhs_scope, rhs_scope); - bool rhs_contains_lhs = contains(rhs_scope, lhs_scope); - if (lhs_contains_rhs && rhs_contains_lhs) return kEqual; - if (lhs_contains_rhs) return kLhsContainsRhs; - if (rhs_contains_lhs) return kRhsContainsLhs; - return kIncomparable; -} - -BranchType CondStateMap::FindBranchOf(CondId id, OutputTensor predicate) const { +BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const { if (IsEmpty(id)) return BranchType::kNeither; - absl::optional<BranchType> b; const CondState& nodes = *id; - for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) { - if (it->type == CondStateMap::CondNode::Type::kSwitch && - it->predicate == predicate) { - if (b.has_value()) { - b = MeetBranch(*b, it->branch); - } else { - b = it->branch; - } - if (*b == BranchType::kNeither) { - LOG(FATAL) << "Inconsistent state for node: " << DebugString(id); - } - } - } - return b.has_value() ? 
*b : BranchType::kNeither; + auto it = nodes.find(predicate); + if (it == nodes.end()) return BranchType::kNeither; + return it->second; } -StatusOr<CondStateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { - VLOG(4) << "Joining src=" << DebugString(src) << " [" << src +StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge( + StateMap::CondId src, StateMap::CondId dst) { + VLOG(5) << "Joining src=" << DebugString(src) << " [" << src << "] and dst=" << DebugString(dst) << " [" << dst << "]"; - if (cond_state_map_.IsEmpty(dst) || cond_state_map_.IsDead(src)) return src; - if (cond_state_map_.IsDead(dst)) return dst; + if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src; + if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst; // Nothing to do if the CondState is the same. if (src == dst) return src; - CondStateMap::CondId src_scope; - CondStateMap::CondId dst_scope; - if (!cond_state_map_.ScopeIn(src, &src_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(src)); - if (!cond_state_map_.ScopeIn(dst, &dst_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(dst)); - - auto result = cond_state_map_.LhsHoldsWhereverRhsHolds(src_scope, dst_scope); - switch (result) { - case CondStateMap::kIncomparable: - return errors::InvalidArgument( - "Graph contains node with inputs predicated on incompatible " - "predicates: ", - DebugString(src), " and ", DebugString(dst)); - case CondStateMap::kEqual: - // If both respect the same predicates, propagate the longer constraint. - if ((src != nullptr && dst == nullptr) || - (src != nullptr && dst != nullptr && src->size() > dst->size())) - return src; - else - return dst; - case CondStateMap::kLhsContainsRhs: - // src contains dst, so dst is already more restrictive. 
- return dst; - case CondStateMap::kRhsContainsLhs: - // dst contains src, so src is more restrictive. - return src; - } -} - -StatusOr<CondStateMap::CondState::const_iterator> -FindThenElseSwitchForPredicate(const OutputTensor& pred, - CondStateMap::CondId id) { - for (auto it = id->begin(); it != id->end(); ++it) { - // Along every path one there can be only one instance of a then or else - // switch for a given predicate, so return once found. - if (it->type == CondStateMap::CondNode::Type::kSwitch && - it->predicate == pred && - (it->branch == BranchType::kThenBranch || - it->branch == BranchType::kElseBranch)) - return it; + StateMap::CondState both = *src; + for (const auto& kv : *dst) { + auto it = both.find(kv.first); + if (it == both.end()) { + both.insert(kv); + } else { + if (it->second != kv.second) { + return errors::InvalidArgument( + "Graph contains node with inputs predicated on incompatible " + "predicates: ", + DebugString(src), " and ", DebugString(dst)); + } + } } - return errors::Internal("Unable to find then/else branch with predicate ", - DebugString(pred), " for ", DebugString(id)); + return state_map_.GetCondId(both); } -StatusOr<CondStateMap::CondId> FunctionalizeCond::JoinCondStatesMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { +StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesMerge( + Node* merge, StateMap::CondId src, StateMap::CondId dst) { // Determine the flow state when joining two states for a merge // node. Combining the two states for a merge node is effectively performing a // disjunction of the states along the different input edges. For a merge that @@ -956,91 +931,56 @@ StatusOr<CondStateMap::CondId> FunctionalizeCond::JoinCondStatesMerge( // followed by s(p, both). 
VLOG(4) << "Joining (for merge) " << DebugString(src) << " and " << DebugString(dst); - if (cond_state_map_.IsEmpty(dst)) return src; - - if (cond_state_map_.IsDead(src)) return src; - if (cond_state_map_.IsDead(dst)) return dst; - - CondStateMap::CondId src_scope; - CondStateMap::CondId dst_scope; - if (!cond_state_map_.ScopeIn(src, &src_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(src)); - if (!cond_state_map_.ScopeIn(dst, &dst_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(dst)); - - TF_RET_CHECK(src_scope != nullptr && dst_scope != nullptr) - << "Illegal merge inputs from outer scope: src=" << DebugString(src) - << " dst=" << DebugString(dst); - auto src_it = src_scope->begin(); - auto dst_it = dst_scope->begin(); - - // Find branch divergent condition. - OutputTensor pred; - while (src_it != src_scope->end() && dst_it != dst_scope->end()) { - if (*src_it != *dst_it) { - VLOG(5) << "Diverges with: " << DebugString(*src_it) << " and " - << DebugString(*dst_it); - if (!(src_it->predicate == dst_it->predicate)) { - return errors::InvalidArgument( - "Unable to find common predicate which holds for one input " - "but not the other of the merge node."); - } - pred = src_it->predicate; - break; - } - ++src_it; - ++dst_it; - } - - if (pred.node == nullptr) - return errors::InvalidArgument("Unable to determine predicate for merge."); - - TF_ASSIGN_OR_RETURN(auto div_src_it, - FindThenElseSwitchForPredicate(pred, src)); - TF_ASSIGN_OR_RETURN(auto div_dst_it, - FindThenElseSwitchForPredicate(pred, dst)); - TF_RET_CHECK(*div_src_it != *div_dst_it); - - CondStateMap::CondState result; - // Populate result with the longest/most restrictive path up to the divergent - // node. 
For example, if the one input is `[switch(pred:0, then)]` and the - // other is `[switch(pred:0, both), merge, switch(pred:0, else)]` (as created - // in gradient of cond test), then the resultant state here should be - // `[switch(pred:0, both), merge, switch(pred:0, both)]`. - if (std::distance(src->begin(), div_src_it) > - std::distance(dst->begin(), div_dst_it)) { - result.assign(src->begin(), std::next(div_src_it)); + if (state_map_.IsEmpty(dst)) return src; + + if (state_map_.IsDead(src)) return src; + if (state_map_.IsDead(dst)) return dst; + + std::vector<StateMap::CondState::value_type> diff; + StateMap::CondState merged; + std::set_symmetric_difference(src->begin(), src->end(), dst->begin(), + dst->end(), std::back_inserter(diff), + CondStateLess()); + std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(), + std::inserter(merged, merged.begin()), CondStateLess()); + + // Update mapping from merge node to predicate. + if (diff.size() == 2) { + auto pred = diff[0].first; + bool different_branches = (diff[0].second != diff[1].second) && + (diff[0].second == BranchType::kThenBranch || + diff[0].second == BranchType::kElseBranch) && + (diff[1].second == BranchType::kThenBranch || + diff[1].second == BranchType::kElseBranch); + if (!(pred == diff[1].first) || !different_branches) + return errors::InvalidArgument( + "Unable to determine predicate for merge node"); + merge_to_predicate_[merge] = pred; } else { - result.assign(dst->begin(), std::next(div_dst_it)); + return errors::InvalidArgument( + "Merge of two inputs that differ on more than one predicate ", + DebugString(src), " and ", DebugString(dst)); } - result.back().branch = BranchType::kBoth; - return cond_state_map_.GetUniqueId(result); + + return state_map_.GetCondId(merged); } -CondStateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { +StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { Node* src = e->src(); - CondStateMap::CondId id = 
cond_state_map_.LookupId(e->src()); - if (IsMerge(src)) { - CondStateMap::CondState state; - if (id != nullptr) state = *id; - state.emplace_back(CondStateMap::CondNode::Type::kMerge); - return cond_state_map_.GetUniqueId(state); - } + StateMap::CondId id = state_map_.LookupCondId(e->src()); + + // Dead nodes only propagate dead state. + if (state_map_.IsDead(id)) return id; + if (IsSwitch(src)) { - CondStateMap::CondState state; + StateMap::CondState state; if (id != nullptr) state = *id; - if (e->IsControlEdge()) { - state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src, - BranchType::kBoth); - } else { - state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src, - BranchType(e->src_output())); + OutputTensor predicate; + TF_CHECK_OK(GetSwitchPredicate(*src, &predicate)); + if (!e->IsControlEdge()) { + state[predicate] = BranchType(e->src_output()); } - return cond_state_map_.GetUniqueId(state); + return state_map_.GetCondId(state); } return id; } @@ -1049,22 +989,21 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { // Only Merge nodes with two inputs are supported, but if this is a redundant // merge, then the dead edge may already have been removed (if due to a // switch) and so the input count would be incorrect. 
- if (cond_state_map_.IsDead(cond_state_map_.LookupId(dst))) - return Status::OK(); + if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK(); int data_inputs = 0; for (auto e : dst->in_edges()) { Node* src = e->src(); VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " " - << cond_state_map_.CondStateToString(src); + << state_map_.CondStateToString(src); if (!src->IsOp()) continue; if (!e->IsControlEdge()) ++data_inputs; - CondStateMap::CondId prop = StateAlongEdge(e); - auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst)); + StateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst)); TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", FormatNodeForError(*dst)); - cond_state_map_.ResetId(dst, id_or.ValueOrDie()); + state_map_.ResetCondId(dst, id_or.ValueOrDie()); } // Incomplete Merge nodes are not supported. @@ -1076,27 +1015,20 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { return Status::OK(); } -Status FunctionalizeCond::DetermineCondState(Node* dst) { - // The logic for the merge and non-merge case differ: for non-merge it is - // the most restrictive CondState, while for merge nodes the - // resultant state is less restrictive than either. - if (IsMerge(dst)) { - TF_RETURN_IF_ERROR(DetermineCondStateMerge(dst)); - } else { - // Handle non-merge join. - for (auto e : dst->in_edges()) { - VLOG(5) << "Processing forward flow for: " << e->DebugString() << " " - << cond_state_map_.CondStateToString(dst); - Node* src = e->src(); - if (!src->IsOp()) continue; - - // Joining the state between the current and propagated state. 
- CondStateMap::CondId prop = StateAlongEdge(e); - auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst)); - TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", - FormatNodeForError(*dst)); - cond_state_map_.ResetId(dst, id_or.ValueOrDie()); - } +Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) { + // Handle non-merge join. + for (auto e : dst->in_edges()) { + VLOG(4) << "Processing forward flow for: " << e->DebugString() << " " + << state_map_.CondStateToString(dst); + Node* src = e->src(); + if (!src->IsOp()) continue; + + // Joining the state between the current and propagated state. + StateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst)); + state_map_.ResetCondId(dst, id_or.ValueOrDie()); } return Status::OK(); } @@ -1104,8 +1036,7 @@ Status FunctionalizeCond::DetermineCondState(Node* dst) { Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { // Handle redundant merge nodes. A merge node is considered redundant if // one input edge is dead while the other has a value. - if (!cond_state_map_.IsDead(cond_state_map_.LookupId(node))) - return Status::OK(); + if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK(); const Edge* non_dead_edge = nullptr; for (auto e : node->in_edges()) { @@ -1113,8 +1044,8 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { Node* src = e->src(); // Handle merge with dead state. 
- const auto& src_id = cond_state_map_.LookupId(src); - if (!cond_state_map_.IsDead(src_id)) { + const auto& src_id = state_map_.LookupCondId(src); + if (!state_map_.IsDead(src_id)) { non_dead_edge = e; break; } @@ -1124,7 +1055,7 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { return errors::InvalidArgument("Merge node ", FormatNodeForError(*node), " has no non-dead inputs."); } - cond_state_map_.MarkDead(node); + state_map_.MarkDead(node); delete_nodes_.push_back(node->id()); VLOG(5) << "removing redundant merge: " << node->name(); while (!node->out_edges().empty()) { @@ -1149,16 +1080,33 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { // along one. The checking of predicate is based on the exact predicate // (rather than boolean equivalence) and aimed at redundant switches as // currently generated by gradient code. + StateMap::CondId dst_id = state_map_.LookupCondId(node); + if (state_map_.IsDead(dst_id)) return Status::OK(); + + BranchType b; OutputTensor pred; TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred)); - auto dst_id = cond_state_map_.LookupId(node); - BranchType b = cond_state_map_.FindBranchOf(dst_id, pred); + // Determine if we are already on a branch where the switch predicate is - // true/false. - if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) - return Status::OK(); + // true/false. Consider both the data and predicate to determine if the + // node is redundant (skipping over identity node). 
+ b = state_map_.FindBranchOf(dst_id, pred); + if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) { + OutputTensor val; + const Edge* e; + TF_RETURN_IF_ERROR(node->input_edge(0, &e)); + val = OutputTensor(e->src(), e->src_output()); + while (IsIdentity(val.node)) { + TF_RETURN_IF_ERROR(val.node->input_edge(0, &e)); + val = OutputTensor(e->src(), e->src_output()); + } + b = state_map_.FindBranchOf(dst_id, val); + if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) + return Status::OK(); + } - VLOG(5) << "Redundant switch " << node->name(); + VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " " + << DebugString(dst_id); const Edge* value_edge; TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge)); Node* val_node = value_edge->src(); @@ -1171,19 +1119,19 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { graph_->RemoveEdge(e); if (switch_branch == Graph::kControlSlot) { if (IsMerge(dst_node)) { - auto id_or = - JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node)); + auto id_or = JoinCondStatesMerge(dst_node, dst_id, + state_map_.LookupCondId(dst_node)); TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", FormatNodeForError(*dst_node)); - cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); + state_map_.ResetCondId(dst_node, id_or.ValueOrDie()); } else { auto id_or = - JoinCondStatesNonMerge(dst_id, cond_state_map_.LookupId(dst_node)); + JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node)); TF_RETURN_IF_ERROR(id_or.status()); - cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); + state_map_.ResetCondId(dst_node, id_or.ValueOrDie()); } } else if (BranchType(switch_branch) != b) { - cond_state_map_.MarkDead(dst_node); + state_map_.MarkDead(dst_node); delete_nodes_.push_back(dst_node->id()); continue; } @@ -1195,20 +1143,47 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { return Status::OK(); } -Status FunctionalizeCond::DetermineCondStates( - 
std::vector<Node*> rev_topo_order) { +Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) { // The state that is propagated along the given edge. for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) { Node* dst = *it; TF_RETURN_IF_ERROR(DetermineCondState(dst)); + TF_RETURN_IF_ERROR(DetermineAncestorState(dst)); if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst)); if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst)); - VLOG(5) << dst->name() << " :: " << cond_state_map_.CondStateToString(dst); + VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst) + << " @ " << state_map_.AncestorStateToString(dst); + if (VLOG_IS_ON(10)) DumpGraphWithCondState("cond_it"); } return Status::OK(); } +Status FunctionalizeCond::DetermineAncestorState(Node* dst) { + StateMap::AncestorId id = nullptr; + StateMap::AncestorState state; + + auto insert = [&](StateMap::AncestorId id, Node* src) { + auto other_id = state_map_.LookupAncestorId(src); + if (other_id != id && other_id != nullptr) { + state.insert(other_id->begin(), other_id->end()); + } + if (IsSwitch(src) || IsMerge(src)) { + state.insert(src); + } + return state_map_.GetAncestorId(state); + }; + + // Compute the union of all the switch/merge nodes that affects the input of + // dst. + for (auto e : dst->in_edges()) { + Node* src = e->src(); + id = insert(id, src); + } + state_map_.ResetAncestorId(dst, id); + return Status::OK(); +} + void FunctionalizeCond::DeleteReachableNodes() { // Delete all nodes that have been extracted or are reachable from // deleted/dead nodes. 
The input and outgoing edges should have already been @@ -1239,16 +1214,8 @@ void FunctionalizeCond::SortMergeNodes(std::vector<Node*>* merge_order) { inner_to_outer_merge_order.reserve(merge_order->size()); for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) { Node* merge = *it; - CondStateMap::CondId id = cond_state_map_.LookupId(merge); - int depth = 0; - for (auto cond_node_it = id->begin(); cond_node_it != id->end(); - ++cond_node_it) { - if (cond_node_it->type == CondStateMap::CondNode::Type::kSwitch && - (cond_node_it->branch == BranchType::kThenBranch || - cond_node_it->branch == BranchType::kElseBranch)) { - ++depth; - } - } + StateMap::CondId id = state_map_.LookupCondId(merge); + int depth = id != nullptr ? id->size() : 0; inner_to_outer_merge_order.emplace_back(depth, merge); } std::stable_sort( @@ -1271,10 +1238,10 @@ Status FunctionalizeCond::FunctionalizeInternal() { // determine deeper equivalence). We shall refer to this structure as the // CondState; // 3. Sort the merge nodes by nesting depth; - // 4. Extract merge nodes together that have the same CondState and whose - // input nodes have the same state from the innermost to the outermost into - // IfOps; Note: In the above only nodes paths that converge to a merge node - // will be considered for removal. + // 4. Extract merge nodes together that have the same CondState and + // AncestorState from the innermost to the outermost into IfOps; + // Note: In the above only nodes that feed into a merge node will be + // considered for functionalization. 
// Perform a DFS over the graph and // * Determine the reverse topological order of the nodes (there should be no @@ -1306,40 +1273,40 @@ Status FunctionalizeCond::FunctionalizeInternal() { return Status::OK(); } - TF_RETURN_IF_ERROR(DetermineCondStates(std::move(rev_topo_order))); - + TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order))); if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id"); // Sort the merge nodes from innermost outwards. SortMergeNodes(&merge_order); - // Extract from innermost out. - for (auto it = merge_order.begin(); it != merge_order.end(); ++it) { - Node* merge = *it; - auto id = cond_state_map_.LookupId(merge); - if (cond_state_map_.IsDead(id)) continue; - - // Construct a Conditional with the predicate of the merge (which is the - // last entry of the CondState for the merge) and this as parent. - DCHECK(id->back().predicate.node != nullptr); - Conditional cond(id->back().predicate, this, &cond_state_map_); - TF_RETURN_IF_ERROR(cond.AddMerge(merge)); - - // Find all merge nodes with the same CondId. This is done repeatedly as - // the CondId can change due replaced conditionals. E.g., the one branch - // could previously have had a conditional nested in it, and so would have - // had CondState with sub-state [switch(p,b),m] (where p is some predicate), - // post removing the nested conditional that sub-state would no longer be - // path of the propagated state along that path. - auto end = merge_order.end(); - for (auto merge_candidate_it = std::next(it); merge_candidate_it != end; - ++merge_candidate_it) { - auto merge_candidate_it_id = - cond_state_map_.LookupId(*merge_candidate_it); - if (merge_candidate_it_id != id) continue; - TF_RETURN_IF_ERROR(cond.AddMerge(*merge_candidate_it)); + // Cluster merge nodes by CondId and AncestorId in order of nesting. 
+ using ClusterPair = std::pair<StateMap::CondId, StateMap::AncestorId>; + std::deque<std::vector<Node*>> merge_clusters; + std::map<ClusterPair, int> merge_cluster_index; + for (Node* merge : merge_order) { + auto cond_id = state_map_.LookupCondId(merge); + if (state_map_.IsDead(cond_id)) continue; + + ClusterPair key = + std::make_pair(cond_id, state_map_.LookupAncestorId(merge)); + auto idx = merge_cluster_index.find(key); + if (idx == merge_cluster_index.end()) { + merge_cluster_index[key] = merge_clusters.size(); + merge_clusters.push_back({merge}); + } else { + merge_clusters[idx->second].emplace_back(merge); } + } + // Extract the conditionals from inner most to outer most. Extracting from + // innermost to outermost enables the extraction pass to stop once it + // encounters a Switch node instead of having to keep track of Switch/Merge + // nodes seen. + for (const auto& cluster : merge_clusters) { + // Construct a Conditional with the predicate of the merge. + Conditional cond(merge_to_predicate_.at(cluster.front()), this, + &state_map_); + for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge)); TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_)); if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); @@ -1359,7 +1326,9 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) { for (Node* n : graph_->nodes()) { n->ClearAttr(kCondGroupDebugAttr); - n->AddAttr(kCondGroupDebugAttr, cond_state_map_.CondStateToString(n)); + n->AddAttr(kCondGroupDebugAttr, + strings::StrCat(state_map_.CondStateToString(n), "_", + state_map_.AncestorStateToString(n))); } LOG(INFO) << "FunctionalizeControlFlow (" << name << "): " << dump_graph::DumpGraphToFile( diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 86436011c6..28301150ea 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -43,105 +43,88 @@ enum class 
BranchType { kNeither = 3, }; -// CondStateMap is responsible for mapping from each graph Node to a CondState, -// where each CondState is the array of CondNodes (corresponding to switch, -// merge or dead states) as described below. For efficiency, this class interns -// the CondState, so that CondState equality comparisons are simply pointer +// StateMap is responsible for mapping from each graph Node to +// * a CondState, where each CondState is a map from predicate to branch (i,e., +// what predicates have to hold or not hold). +// * a AncestorState, where each AncestorState is a set of switch/merge nodes +// that are an ancestor of the node in the graph; +// For efficiency, this class interns the CondState (AncestorState), so that +// CondState (AncestorState) equality comparisons are simply pointer // comparisons. -class CondStateMap { +class StateMap { public: - explicit CondStateMap(Graph* graph); - - // Represents an entry in the CondState. An entry can either be the - // switch (along with predicate), merge, or dead: - // * switch node indicates a node that is executed along a branch with the - // given predicate - a branch can be then, else or both; - // * merge node indicates that the node is executed as output of a merge; - // * dead indicates that this node can never be executed; - struct CondNode { - enum class Type { kSwitch = 1, kMerge = 2, kDead = 3 }; - - CondNode(Type type, Node* switch_node = nullptr, - BranchType branch = BranchType::kNeither); - - string ToString() const; - bool operator==(const CondNode& other) const; - bool operator!=(const CondNode& other) const; - - // Type of node. - Type type; - - // Predicate and branch, only used when type is kSwitch. - OutputTensor predicate; - BranchType branch; + explicit StateMap(Graph* graph); + + // Compare two OutputTensors by (node id, index). 
+ struct OutputTensorLess { + bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const; }; - // A node in the graph is executed when multiple conditions hold. The order - // represents the nesting of the predicates that hold and is used when - // extracting the nested conditionals. - using CondState = std::vector<CondNode>; + // A node in the graph is executed when multiple conditions hold. Keep track + // of the predicates that must hold for a node to execute. + using CondState = std::map<OutputTensor, BranchType, OutputTensorLess>; // Every unique ID is mapped to a CondState. using CondId = const CondState*; + // Keep track of which switch/merge node's feed into a node's values. + using AncestorState = std::set<Node*>; + + // Every unique ID is mapped to a AncestorState. + using AncestorId = const AncestorState*; + // Returns the CondId for a given node. - CondId LookupId(const Node* node) const; + CondId LookupCondId(const Node* node) const; // Returns the unique CondId for CondState. - CondId GetUniqueId(const CondState& state); + CondId GetCondId(const CondState& state); + + // Resets the CondId for a given node. + void ResetCondId(const Node* node, CondId id); + + // Returns the AncestorId for a given node. + AncestorId LookupAncestorId(const Node* node) const; + + // Returns the unique AncestorId for CondState. + AncestorId GetAncestorId(const AncestorState& state); + + // Resets the AncestorId for a given node. + void ResetAncestorId(const Node* node, AncestorId id); // Returns the CondState for a Node. // REQUIRES: node has a non-empty CondState. const CondState& LookupState(const Node* node) const; - // Resets the CondId for a given node. - void ResetId(const Node* node, CondId id); - // Marks `node` as dead. void MarkDead(const Node* node); // Determine branch execution of CondState. BranchType FindBranchOf(CondId id, OutputTensor predicate) const; - // Enum to represent whether one cond flow state contains another. 
- enum ContainsResult { - kIncomparable, - kEqual, - kLhsContainsRhs, - kRhsContainsLhs - }; - - // Returns whether the lhs CondState holds wherever rhs CondState hols. I.e., - // [(p,t)] contains [(p,t), (r,t)]. - ContainsResult LhsHoldsWhereverRhsHolds(CondId lhs, CondId rhs); - // Returns textual representation of node's CondState. string CondStateToString(const Node* node) const; string CondStateToString(CondId id) const; + // Returns textual representation of node's AncestorState. + string AncestorStateToString(const Node* node) const; + // Returns whether the cond state is the dead state. bool IsDead(CondId id) const; // Returns whether the cond state is the empty state. bool IsEmpty(CondId id) const; - // Computes the predicates that have to hold for a node to execute and returns - // whether it was possible to determine the predicates that must hold. `scope` - // is populated with these predicates. Scope differs from state in that it - // does not include merge and both nodes. - bool ScopeIn(CondId id, CondId* scope); - private: - // Hash for CondNode and CondState. - struct CondHash { - size_t operator()(const CondNode& item) const; - size_t operator()(const CondState& vec) const; + // Hash for CondState and AncestorState. + struct Hash { + size_t operator()(const CondState& map) const; + size_t operator()(const AncestorState& map) const; }; // Set to keep track of unique CondStates. // Pointers to the entries in the unordered set are used as identifiers: // unordered_set guarantees that the pointers remain the same. - std::unordered_set<CondState, CondHash> condstate_set_; + std::unordered_set<CondState, Hash> condstate_set_; // Mapping from Node id to CondId. std::vector<CondId> node_to_condid_map_; @@ -150,7 +133,12 @@ class CondStateMap { // from Node id in the original graph to the CondId, but there will be nodes // added to the original graph (such as If nodes) whose CondState needs to be // tracked too. 
- std::unordered_map<int, CondId> added_node_mapping_; + std::unordered_map<int, CondId> added_node_condid_mapping_; + + // AncestorId variants of the CondId members. + std::unordered_set<AncestorState, Hash> ancestorstate_set_; + std::vector<AncestorId> node_to_ancestorid_map_; + std::unordered_map<int, AncestorId> added_node_ancestorid_mapping_; // Identifier of the dead flow state. The empty flow state is represented with // a nullptr. @@ -173,7 +161,8 @@ class FunctionalizeCond { // Add a If node to the graph defined by def that will, amongst other, replace // replacee in the graph. - xla::StatusOr<Node*> AddIfNode(const NodeDef& def, const Node* replacee); + xla::StatusOr<Node*> AddIfNode(const NodeDef& def, const Node* replacee, + const OutputTensor& predicate); // Propagates the state of a newly inserted node. Status PropagateUpdatedState(const Node* replacee); @@ -185,35 +174,42 @@ class FunctionalizeCond { FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); // Performs the actual cond functionalization. Iterate over groups of merge - // nodes (linked by common predicate & CondIds of the incomming edges), - // from innermost to outermost, and extract into If nodes. + // nodes (linked by common predicates & ancestor IDs), from innermost to + // outermost, and extract into If nodes. Status FunctionalizeInternal(); // Returns the forward flow state propagated along edge `e`. - // This may modify cond_state_map_. - CondStateMap::CondId StateAlongEdge(const Edge* e); + // This may modify state_map_. + StateMap::CondId StateAlongEdge(const Edge* e); - // Determines the CondState of all the nodes in the given vector where - // the input is expected in reverse topological order. - // This populates the cond_state_map_. - Status DetermineCondStates(std::vector<Node*> rev_topo_order); + // Determines the CondState and AncestorState of all the nodes in the given + // vector where the input is expected in reverse topological order. 
+ // This populates the state_map_. + Status DetermineStates(std::vector<Node*> rev_topo_order); // Determine the CondState for a given node using the incomming edges // to the node. Note: it is expected that this node's CondState is only // determined once its input's CondState is. - Status DetermineCondState(Node* dst); + Status DetermineCondState(Node* dst) { + if (IsMerge(dst)) return DetermineCondStateMerge(dst); + return DetermineCondStateNonMerge(dst); + } // Helper functions for DetermineCondState. + Status DetermineCondStateNonMerge(Node* dst); Status DetermineCondStateMerge(Node* dst); - // Helper functions for DetermineCondStates. Determines the dst node's - // CondState by joining the src and dst's CondState where either - // the dst node is a merge or not. - // These may modify cond_state_map_. - xla::StatusOr<CondStateMap::CondId> JoinCondStatesMerge( - CondStateMap::CondId src, CondStateMap::CondId dst); - xla::StatusOr<CondStateMap::CondId> JoinCondStatesNonMerge( - CondStateMap::CondId src, CondStateMap::CondId dst); + // Determines the dst node's CondState by joining the src and dst's CondState + // where either the dst node is a merge or not. + // These may modify state_map_. + xla::StatusOr<StateMap::CondId> JoinCondStatesMerge(Node* merge, + StateMap::CondId src, + StateMap::CondId dst); + xla::StatusOr<StateMap::CondId> JoinCondStatesNonMerge(StateMap::CondId src, + StateMap::CondId dst); + + // Determines which switch/merge nodes are ancestors of this node. + Status DetermineAncestorState(Node* dst); // Checks if a merge node is redundant and if so removes it from the graph. Status RemoveRedundantMerge(Node* node); @@ -228,9 +224,13 @@ class FunctionalizeCond { // Deletes all nodes in/consumers of `delete_nodes_`. void DeleteReachableNodes(); - // Member used to unique the CondState to a unique CondId and keep track of - // CondState/CondId per Node. 
- CondStateMap cond_state_map_; + // Member used to unique the CondState to a unique CondId (AncestorState to a + // unique AncestorId) and keep track of CondState/CondId + // (AncestorState/AncestorId) per Node. + StateMap state_map_; + + // Mapping from merge nodes to predicate. + std::unordered_map<Node*, OutputTensor> merge_to_predicate_; // Nodes to be deleted. std::deque<int> delete_nodes_; diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index a27f889392..b0aabd63bb 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -37,28 +37,23 @@ class FunctionalizeCondTest : public ::testing::Test { flib_def_.get())); } - CondStateMap::CondId GetUniqueId( - const CondStateMap::CondStateMap::CondState& state) { - return fc_->cond_state_map_.GetUniqueId(state); + StateMap::CondId GetUniqueId(const StateMap::StateMap::CondState& state) { + return fc_->state_map_.GetCondId(state); } - xla::StatusOr<CondStateMap::CondId> JoinCondStatesNonMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { - return fc_->JoinCondStatesNonMerge(src, dst); - } - - xla::StatusOr<CondStateMap::CondId> JoinCondStatesMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { - return fc_->JoinCondStatesMerge(src, dst); + string GetString(const StateMap::StateMap::CondId id) { + return fc_->state_map_.CondStateToString(id); } - bool ScopeIn(CondStateMap::CondId ff, CondStateMap::CondId* scope) { - return fc_->cond_state_map_.ScopeIn(ff, scope); + xla::StatusOr<StateMap::CondId> JoinCondStatesNonMerge(StateMap::CondId src, + StateMap::CondId dst) { + return fc_->JoinCondStatesNonMerge(src, dst); } - CondStateMap::ContainsResult LhsHoldsWhereverRhsHolds( - CondStateMap::CondId lhs, CondStateMap::CondId rhs) { - return fc_->cond_state_map_.LhsHoldsWhereverRhsHolds(lhs, rhs); + xla::StatusOr<StateMap::CondId> JoinCondStatesMerge(Node* n, + 
StateMap::CondId src, + StateMap::CondId dst) { + return fc_->JoinCondStatesMerge(n, src, dst); } FunctionDefLibrary fdef_lib_; @@ -69,50 +64,6 @@ class FunctionalizeCondTest : public ::testing::Test { namespace { -TEST_F(FunctionalizeCondTest, ScopeIn) { - Tensor pred_tensor(DT_BOOL, TensorShape()); - pred_tensor.flat<bool>().setZero(); - Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred"); - Tensor val_tensor(DT_INT32, TensorShape()); - val_tensor.flat<int>().setZero(); - Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); - Node* s = test::graph::Switch(graph_.get(), val, pred); - - { - CondStateMap::CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch)); - CondStateMap::CondId id = GetUniqueId(ss); - CondStateMap::CondId scope; - ASSERT_TRUE(ScopeIn(id, &scope)); - ASSERT_TRUE(id == scope); - } - - CondStateMap::CondState empty; - { - CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth)); - ss.emplace_back( - CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge)); - CondStateMap::CondId id = GetUniqueId(ss); - CondStateMap::CondId scope_1; - ASSERT_TRUE(ScopeIn(id, &scope_1)); - ASSERT_TRUE(scope_1 == GetUniqueId(empty)); - ASSERT_TRUE(id != scope_1); - - ss.clear(); - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth)); - id = GetUniqueId(ss); - CondStateMap::CondId scope_2; - ASSERT_TRUE(ScopeIn(id, &scope_2)); - - ASSERT_TRUE(LhsHoldsWhereverRhsHolds(scope_1, scope_2) == - CondStateMap::ContainsResult::kLhsContainsRhs); - } -} - TEST_F(FunctionalizeCondTest, JoinCondStates) { Tensor pred_tensor(DT_BOOL, TensorShape()); pred_tensor.flat<bool>().setZero(); @@ -120,22 +71,18 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { Tensor val_tensor(DT_INT32, TensorShape()); val_tensor.flat<int>().setZero(); Node* 
val = test::graph::Constant(graph_.get(), val_tensor, "val"); - Node* s = test::graph::Switch(graph_.get(), val, pred); + Node* m = test::graph::Merge(graph_.get(), val, val); - CondStateMap::CondId empty = GetUniqueId({}); - - CondStateMap::CondId then_branch; + StateMap::CondId then_branch; { - CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch)); + StateMap::CondState ss; + ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kThenBranch)); then_branch = GetUniqueId(ss); } - CondStateMap::CondId else_branch; + StateMap::CondId else_branch; { - CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kElseBranch)); + StateMap::CondState ss; + ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kElseBranch)); else_branch = GetUniqueId(ss); } @@ -144,39 +91,14 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { EXPECT_TRUE(errors::IsInvalidArgument(status)); // Merge between then and else branch. - auto joined_or = JoinCondStatesMerge(then_branch, else_branch); + auto joined_or = JoinCondStatesMerge(m, then_branch, else_branch); TF_EXPECT_OK(joined_or.status()); - CondStateMap::CondId joined = joined_or.ValueOrDie(); + StateMap::CondId joined = joined_or.ValueOrDie(); // Merge between then branch and both branch. auto t = JoinCondStatesNonMerge(then_branch, joined); // Note: this is OK in terms of constraint predication, but TF_EXPECT_OK(t.status()); - - // Post merge the propagated forward flow state has an additional merge. 
- CondStateMap::CondId post_merge; - { - CondStateMap::CondState ss; - ss = *joined; - ss.emplace_back( - CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge)); - post_merge = GetUniqueId(ss); - } - - t = JoinCondStatesNonMerge(post_merge, joined); - TF_EXPECT_OK(t.status()); - EXPECT_TRUE(joined == t.ValueOrDie()); - - // No predicate that results in two paths predicated on different conditions - // merge. - t = JoinCondStatesMerge(post_merge, joined); - EXPECT_FALSE(t.ok()); - - // Post the merge we are effectively in the root scope and merging should - // result in the more restrictive post merge state. - t = JoinCondStatesNonMerge(post_merge, empty); - TF_EXPECT_OK(t.status()); - EXPECT_TRUE(post_merge == t.ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index e639028ccd..7f2125f74c 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -990,8 +990,8 @@ XlaOp XlaBuilder::ConvGeneralDilated( TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, instr.window(), - dimension_numbers, feature_group_count)); + lhs_shape, rhs_shape, feature_group_count, + instr.window(), dimension_numbers)); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index a4854f593f..8a05d1b0d7 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -564,18 +564,22 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( dim2.set_base_dilation(lhs_dilation.second); *window.add_dimensions() = dim2; - const Shape& shape = - ShapeInference::InferConvolveShape(lhs_literal->shape(), - rhs_literal->shape(), window, dnums) - .ConsumeValueOrDie(); + const Shape& shape = 
ShapeInference::InferConvolveShape( + lhs_literal->shape(), rhs_literal->shape(), + /*feature_group_count=*/1, window, dnums) + .ConsumeValueOrDie(); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfigProto::DEFAULT); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, precision_config)); HloModuleConfig config; HloModule module("ReferenceUtil", config); auto computation = module.AddEntryComputation(b.Build()); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 26b48cf419..f6cfac6537 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3289,6 +3289,8 @@ tf_cc_test( size = "small", srcs = ["hlo_parser_test.cc"], deps = [ + ":hlo", + ":hlo_casting_utils", ":hlo_matchers", ":hlo_parser", "//tensorflow/compiler/xla:window_util", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 7c078f07d7..3d18fe3be2 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -950,9 +950,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( new_dot_rhs = rhs_slice; } - auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums)); - new_dot->set_precision_config(dot.precision_config()); + auto* new_dot = computation_->AddInstruction( + HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs, + 
new_dot_dnums, dot.precision_config())); if (add_result) { add_result = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -1053,9 +1053,9 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather( const int n = right_operand->shape().dimensions(1 - rhs_contracting_dimension); auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); - auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( - memoized_shape, left_operand, right_operand, dnums)); - memoized_inst->set_precision_config(dot->precision_config()); + auto* memoized_inst = computation_->AddInstruction( + HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, + dnums, dot->precision_config())); // Get pair {start, 0} or {0, start}. HloInstruction* original_start_indices = lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); @@ -1151,9 +1151,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), - rhs->mutable_operand(0), lhs->mutable_operand(0), - dot_dimension_numbers)); - new_dot->set_precision_config(dot->precision_config()); + rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers, + dot->precision_config())); return ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } @@ -2477,8 +2476,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_dimension_numbers.add_lhs_contracting_dimensions(1); dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); - dot->set_precision_config(convolution->precision_config()); + dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, + convolution->precision_config())); return 
ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 43a891e4fa..019840b476 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1013,6 +1013,13 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { 1); } +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { auto builder = HloComputation::Builder(TestName()); HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1044,7 +1051,8 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { dim->set_window_reversal(false); // Create add computation. 
builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums)); + ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(builder.Build()); HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -2260,9 +2268,11 @@ TEST_P(ConvInputPaddingTest, DoTest) { .ValueOrDie(); builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(), - window, dnums) + /*feature_group_count=*/1, window, + dnums) .ValueOrDie(), - lhs_pad, filter, window, dnums)); + lhs_pad, filter, /*feature_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -2368,9 +2378,11 @@ TEST_P(ConvFilterPaddingTest, DoIt) { .ValueOrDie(); auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), - window, dnums) + /*feature_group_count=*/1, window, + dnums) .ValueOrDie(), - input, rhs_pad, window, dnums)); + input, rhs_pad, /*feature_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place // after the transformation. @@ -2522,8 +2534,9 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { HloInstruction* filter = b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter")); - b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, - window, dnums)); + b.AddInstruction(HloInstruction::CreateConvolve( + out_shape, input, filter, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); // TODO(b/80488902): verify this module. 
auto module = HloTestBase::CreateNewModule(); @@ -2901,7 +2914,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums, + DefaultPrecisionConfig(2))); std::unique_ptr<HloComputation> dot_computation(builder.Build()); HloComputation::Builder call_builder(TestName() + ".Call"); @@ -3253,8 +3267,8 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -3329,8 +3343,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { dot_dnums.add_rhs_contracting_dimensions(0); Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3393,8 +3407,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { dot_dnums.add_rhs_contracting_dimensions(0); Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, 
DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3511,8 +3525,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int64 dot_row_size = 1; int64 dot_col_size = spec.n; Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3581,8 +3595,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int64 dot_row_size = spec.m; int64 dot_col_size = 1; Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index a16b85a0a5..eda026ac56 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -63,8 +63,8 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size()); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, - MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); - new_dot->set_precision_config(batch_dot->precision_config()); + MakeDotHlo(new_lhs, new_rhs, new_dim_numbers, + batch_dot->precision_config())); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, 
MakeReshapeHlo(batch_dot->shape(), new_dot)); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index b08705d4c2..d480d72297 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -308,8 +308,11 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8bd1533972..7398f105a0 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1490,10 +1490,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot_ab = builder.AddInstruction( - HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums)); - auto dot_bc = builder.AddInstruction( - HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); + auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot( + shape_2x4, param_a, param_b, dot_dnums, precision_config)); + auto dot_bc = 
builder.AddInstruction(HloInstruction::CreateDot( + shape_3x4, param_b, param_c, dot_dnums, precision_config)); builder.AddInstruction( HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0)); diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 9c81a86bbb..0826380f65 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -223,8 +223,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { filter_mask, expanded_filter, zero_filter)); auto new_convolution = HloInstruction::CreateConvolve( convolution->shape(), convolution->mutable_operand(0), new_filter, - convolution->window(), dim_numbers, /*feature_group_count=*/1); - new_convolution->set_precision_config(convolution->precision_config()); + /*feature_group_count=*/1, convolution->window(), dim_numbers, + convolution->precision_config()); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(new_convolution))); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 098ce17a56..2d9978404c 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -130,9 +130,9 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) { // change the dimension mapping but not the dimension sizes. For // example, input height and width are the same as before the reshapes. 
HloInstruction* new_conv = module->entry_computation()->AddInstruction( - HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, - hlo->window(), new_dnums)); - new_conv->set_precision_config(hlo->precision_config()); + HloInstruction::CreateConvolve( + new_conv_shape, new_input, new_kernel, hlo->feature_group_count(), + hlo->window(), new_dnums, hlo->precision_config())); // Reshape the output back to the shape of the original convolution. TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 547d4c696d..616c453750 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -56,6 +56,13 @@ class ConvCanonicalizationTest : public HloTestBase { static constexpr int kOutputFeatureCount = 64; }; +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in CNHW order. 
@@ -84,7 +91,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}), - input, kernel, conv_window_, dnums)); + input, kernel, /*feature_group_count=*/1, conv_window_, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -146,7 +154,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}), - input, kernel, conv_window_, dnums)); + input, kernel, /*feature_group_count=*/1, conv_window_, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 284929ca07..6bd0a2dd90 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -38,7 +38,11 @@ std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, + precision_config); } TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 09cb10d6ee..b2ba261790 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ 
b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -134,9 +134,9 @@ Status DecomposeBatchDot(HloInstruction* dot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( - dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); - dot_r2->set_precision_config(dot->precision_config()); + auto dot_r2 = computation->AddInstruction( + HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2, + dot_dnums, dot->precision_config())); // Reshape Dot to R3 so we can concat along batch dimension. auto dot_r3 = computation->AddInstruction( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index 46c23db465..9b46bfc098 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -95,6 +95,13 @@ class CudnnConvolutionRewriterTest : public HloVerifiedTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = @@ -107,12 +114,12 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { conv_window.mutable_dimensions(1)->set_size(2); conv_window.mutable_dimensions(1)->set_window_dilation(2); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(activations->shape(), - gradients->shape(), conv_window, - tf_default_dnums_for_backward_filter_) + ShapeInference::InferConvolveShape( + 
activations->shape(), gradients->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, conv_window, - tf_default_dnums_for_backward_filter_)); + activations, gradients, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -135,12 +142,12 @@ TEST_F(CudnnConvolutionRewriterTest, Window conv_window = default_conv_window_; conv_window.mutable_dimensions(1)->set_size(3); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(activations->shape(), - gradients->shape(), conv_window, - tf_default_dnums_for_backward_filter_) + ShapeInference::InferConvolveShape( + activations->shape(), gradients->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, conv_window, - tf_default_dnums_for_backward_filter_)); + activations, gradients, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -170,7 +177,8 @@ TEST_F(CudnnConvolutionRewriterTest, } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -200,7 +208,8 @@ TEST_F(CudnnConvolutionRewriterTest, } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, 
DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -228,7 +237,8 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -272,13 +282,14 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output, - /*rhs=*/reverse_kernel, conv_window, conv_dnums)); + /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window, + conv_dnums, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, conv_dnums) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, conv_window, conv_dnums) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -319,11 +330,11 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - conv_window, + /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, conv_window, - tf_default_dnums_for_backward_input_)); + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -350,12 +361,13 @@ TEST_F(CudnnConvolutionRewriterTest, 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - default_conv_window_, - tf_default_dnums_for_backward_input_) + ShapeInference::InferConvolveShape( + output->shape(), kernel->shape(), /*feature_group_count=*/1, + default_conv_window_, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, default_conv_window_, - tf_default_dnums_for_backward_input_)); + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, + default_conv_window_, tf_default_dnums_for_backward_input_, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -402,13 +414,15 @@ TEST_F(CudnnConvolutionRewriterTest, } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( 
ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -449,13 +463,15 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -502,13 +518,15 @@ TEST_F(CudnnConvolutionRewriterTest, forward_conv_col_dim->set_base_dilation(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); const HloComputation* entry_computation = @@ -554,13 +572,15 @@ TEST_F(CudnnConvolutionRewriterTest, forward_conv_col_dim->set_padding_high(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index a2be89511b..0a49d85c6d 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -112,8 +112,11 @@ std::unique_ptr<HloModule> MakeBigGraph() { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums)); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfigProto::DEFAULT); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + vshape, clamp, param_v0, dot_dnums, precision_config)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({dot, param_s, clamp})); auto scalar = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 5f85f14565..576c5ff7a4 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -353,6 +353,13 @@ TEST_F(HeapSimulatorTest, BufferReusedOnce) { (neg_buffer == output_buffer_1)); } +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, 
PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_F(HeapSimulatorTest, MultiplyDot) { auto builder = HloComputation::Builder(TestName()); auto paramA = builder.AddInstruction( @@ -366,8 +373,8 @@ TEST_F(HeapSimulatorTest, MultiplyDot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); // The buffer for dot is the output, and it cannot be shared with the buffer // for mul, since dot isn't elementwise. @@ -402,8 +409,8 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA)); @@ -440,10 +447,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot0 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); - auto dot1 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); + auto dot0 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); + auto dot1 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2))); // The buffer for dot1 is the output. No buffers can be shared. 
The buffer // for mul is freed before the end, since it's no longer used after dot0 @@ -481,10 +488,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot0 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); - auto dot1 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); + auto dot0 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); + auto dot1 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1})); diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index f7ed1b0316..a2c1ce34c6 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -601,8 +601,11 @@ TEST_F(HloComputationTest, Stringification) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -633,8 +636,11 @@ TEST_F(HloComputationTest, StringificationIndent) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, 
PrecisionConfigProto::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -666,8 +672,11 @@ TEST_F(HloComputationTest, StringificationCanonical) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 19ffb465c0..a6ae0337a5 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -61,15 +61,18 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand, } StatusOr<HloInstruction*> MakeConvolveHlo( - HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers) { + HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfigProto& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); - TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), - window, dimension_numbers)); + TF_ASSIGN_OR_RETURN(Shape convolve_shape, + ShapeInference::InferConvolveShape( + lhs->shape(), rhs->shape(), feature_group_count, + window, dimension_numbers)); 
return computation->AddInstruction(HloInstruction::CreateConvolve( - convolve_shape, lhs, rhs, window, dimension_numbers)); + convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers, + precision_config)); } StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand, @@ -164,15 +167,17 @@ StatusOr<HloInstruction*> MakeConcatHlo( HloInstruction::CreateConcatenate(concat_shape, operands, dimension)); } -StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers) { +StatusOr<HloInstruction*> MakeDotHlo( + HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN( Shape dot_shape, ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers)); - return computation->AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); + return computation->AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dim_numbers, precision_config)); } StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands, diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index a1c4b374d1..1c82956907 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -48,8 +48,9 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand, // Creates a convolution HLO instruction and adds it to the computation // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). 
StatusOr<HloInstruction*> MakeConvolveHlo( - HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfigProto& precision_config); // Creates a transpose HLO instruction and adds it to the computation containing // `operand`. @@ -97,8 +98,10 @@ StatusOr<HloInstruction*> MakeConcatHlo( // Creates a Dot HLO instruction and adds it to the computation containing `lhs` // and `rhs` (both must be in the same computation). -StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers); +StatusOr<HloInstruction*> MakeDotHlo( + HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config); // Creates a Map HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. 
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index d1a96c10f8..62eea2b06c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2334,8 +2334,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 8b2846e0c2..113fd18eae 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -51,6 +51,10 @@ int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { return FindOrDefault(instruction_to_domain_, instruction, -1); } +int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const { + return FindOrDie(domain_metadata_id_, instruction); +} + Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); // We only check operands, so we are sure to not process the empty domain from @@ -93,6 +97,43 @@ Status HloDomainMap::Populate(HloComputation* computation) { CreateDomain(instruction, instructions_post_order)); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } + TF_RETURN_IF_ERROR(PopulateDomainMetadataMap()); + return Status::OK(); +} + +Status 
HloDomainMap::PopulateDomainMetadataMap() { + auto hash = [](const DomainMetadata* m) { return m->Hash(); }; + auto equal = [](const DomainMetadata* a, const DomainMetadata* b) { + return a->Matches(*b); + }; + tensorflow::gtl::FlatMap<const DomainMetadata*, int64, decltype(hash), + decltype(equal)> + domain_metadata(1024, hash, equal); + + for (auto& domain : instruction_domains_) { + int64 domain_metadata_id = -1; + if (!domain->enter_domains.empty()) { + const HloInstruction* domain_instruction = *domain->enter_domains.begin(); + domain_metadata_id = + domain_metadata + .insert({&domain_instruction->user_side_metadata(), + domain_metadata.size() + 1}) + .first->second; + } else if (!domain->exit_domains.empty()) { + const HloInstruction* domain_instruction = *domain->exit_domains.begin(); + domain_metadata_id = + domain_metadata + .insert({&domain_instruction->operand_side_metadata(), + domain_metadata.size() + 1}) + .first->second; + } else { + domain_metadata_id = 0; + } + TF_RET_CHECK(domain_metadata_id >= 0); + for (HloInstruction* instruction : domain->instructions) { + domain_metadata_id_[instruction] = domain_metadata_id; + } + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 633109249a..56b557d7ce 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -69,6 +69,11 @@ class HloDomainMap { // instruction is not found within any domain. int64 GetDomainId(HloInstruction* instruction) const; + // Returns the unique id of the domain metadata for the domain the given + // instruction belongs to. The given instruction must not be a kDomain + // instruction since each domain instruction is associated with 2 domains. + int64 GetDomainMetadataId(HloInstruction* instruction) const; + private: // Map used for representing instruction ordering, i.e. 
// order_map[a] < order_map[b] means a must be ordered before b. @@ -109,9 +114,14 @@ class HloDomainMap { const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set, const InstructionOrderMap& instructions_order); + // Populates domain_metadata_id_ that maps each HloInstruction to the unique + // ID of its associated domain metatadata. + Status PopulateDomainMetadataMap(); + string domain_kind_; std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_; tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_; + tensorflow::gtl::FlatMap<HloInstruction*, int64> domain_metadata_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index 6c142ee474..302807f816 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -72,6 +72,9 @@ class DomainMetadata { // two matches. virtual bool Matches(const DomainMetadata& other) const = 0; + // Returns the hash value of the metadata. + virtual size_t Hash() const = 0; + // Returns a string representation of the metadata. 
virtual string ToString() const = 0; }; diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 974ab94467..43e74d2f6f 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -99,6 +99,8 @@ class OpNameMetadata : public DomainMetadata { static absl::string_view KindName() { return "opname"; } + size_t Hash() const override { return std::hash<string>()(opname_); } + private: string opname_; }; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 441dcad000..ffb3451164 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -53,7 +53,6 @@ namespace xla { namespace { - template <typename OperandT> StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, @@ -345,7 +344,8 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp( } StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, const Literal& lhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config, const Literal& lhs, const Literal& rhs) { std::unique_ptr<HloInstruction> lhs_instr = HloInstruction::CreateConstant(lhs.CloneToUnique()); @@ -358,7 +358,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp( std::unique_ptr<HloInstruction> cloned_instruction = HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(), - dim_numbers); + dim_numbers, precision_config); return Evaluate(cloned_instruction.get()); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index c2d49e56ac..e13af8e999 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -115,7 
+115,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { HloOpcode opcode, const Literal& operand); StatusOr<std::unique_ptr<Literal>> EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, const Literal& lhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config, const Literal& lhs, const Literal& rhs); protected: diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 7e490d7f32..f586f253da 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -622,6 +622,13 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_P(HloEvaluatorTest, DotRank2AndRank1) { HloComputation::Builder b(TestName()); @@ -649,7 +656,8 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = Evaluate(); @@ -694,7 +702,8 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = Evaluate(); @@ -737,7 +746,8 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { 
dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = Evaluate(); @@ -788,9 +798,10 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { dnums.set_kernel_input_feature_dimension(1); dnums.add_kernel_spatial_dimensions(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = Evaluate(); @@ -842,9 +853,10 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = Evaluate(); @@ -925,9 +937,10 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(1); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, 
DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = Evaluate(); @@ -1002,9 +1015,10 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(1); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = Evaluate(); @@ -1061,9 +1075,10 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = Evaluate(); @@ -1124,9 +1139,10 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = Evaluate(); @@ -1195,9 +1211,10 @@ TEST_P(HloEvaluatorTest, ConvolutionDimensionNumbers dnums = 
XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr<Literal> result = Evaluate(); @@ -1219,6 +1236,67 @@ TEST_P(HloEvaluatorTest, EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } +TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { + HloComputation::Builder b(TestName()); + std::vector<int64> input_dims = {1, 2, 2, 4}; + std::vector<int64> filter_dims = {2, 2, 2, 8}; + Shape input_shape = ShapeUtil::MakeShapeWithType<float>(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType<float>(filter_dims); + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape)); + std::iota(input_elems.begin(), input_elems.end(), -7); + auto input_r1 = LiteralUtil::CreateR1<float>(input_elems); + auto input_r4 = 
input_r1->Reshape(input_dims).ConsumeValueOrDie(); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4))); + + std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape)); + std::iota(filter_elems.begin(), filter_elems.end(), -31); + auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems); + auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4))); + + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, + /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); + module().AddEntryComputation(b.Build()); + + std::unique_ptr<Literal> result = Evaluate(); + + Array4D<float> expected_array(1, 1, 1, 8); + expected_array.FillWithYX( + Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}})); + auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); +} + class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; // Tests that Reduce doesn't lose precision when adding many numbers (because diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index cb27e13e99..6a09bb08f4 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1021,9 +1021,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_EQ(num_spatial_dims + 2, lhs_rank); CHECK_EQ(num_spatial_dims + 2, rhs_rank); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, - window, dnums)); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferConvolveShape( + lhs_shape, 
rhs_shape, conv->feature_group_count(), window, dnums)); CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1046,9 +1047,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto lhs_literal_data = lhs_literal.data<ReturnT>(); auto rhs_literal_data = rhs_literal.data<ReturnT>(); + int64 feature_group_count = conv->feature_group_count(); + auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data](absl::Span<const int64> out_index) { + rhs_literal_data, + feature_group_count](absl::Span<const int64> out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1060,6 +1064,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 output_z_dim = dnums.output_feature_dimension(); const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + const int64 output_z_size = + ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim); ElementwiseT result_val = static_cast<ElementwiseT>(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), @@ -1068,6 +1074,33 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Convolve input feature with kernel. do { for (int64 iz = 0; iz < z_size; ++iz) { + int64 rhs_iz = iz; + // Handle grouped convolutions. + if (feature_group_count > 1) { + // The size of a feature group. + int64 feature_group_size = z_size / feature_group_count; + rhs_iz = iz % feature_group_size; + + // The output feature dimension is a concatenation of convolution + // results from the different groups. 
+ int64 output_feature_group_size = + output_z_size / feature_group_count; + + // Calculate the group index to which the current input feature + // index belongs. + int64 input_group_index = iz / feature_group_size; + + // Calculate the group index to which the current output index + // belongs. + int64 output_group_index = + out_index[output_z_dim] / output_feature_group_size; + if (input_group_index != output_group_index) { + // If the current output index does not belong to the current + // feature group, skip it. + continue; + } + } + int64 lhs_linear_index = 0; lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; @@ -1076,7 +1109,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { int64 rhs_linear_index = 0; rhs_linear_index += out_index[output_z_dim] * rhs_dim_multipliers[kernel_output_z_dim]; - rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; + rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; // Find corresponding spatial dimension index for input (lhs). for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 3041d94fa9..0345a2a5f8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -120,12 +120,19 @@ class NodeFilter { std::function<NodeFilterResult(const HloInstruction* instr)> filter_; }; +// We arbitrarily set this as the boundary between "large" and "small" +// instructions. +bool IsSmall(const HloInstruction* instr) { + return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096; +} + // Node color schemes, used by NodeColorAttributes. 
enum ColorScheme { kBlue, kBrown, kDarkBlue, kDarkGreen, + kDarkOrange, kDarkRed, kGray, kGreen, @@ -158,6 +165,10 @@ NodeColors NodeColorsForScheme(ColorScheme color) { return NodeColors{"filled", "#1565c0", "#003c8f", "white"}; case kDarkGreen: return NodeColors{"filled", "#2e7d32", "#005005", "white"}; + case kDarkOrange: + // This is more of a "medium" orange, made to look close to kOrange; + // there's probably room for a darker weight if desired. + return NodeColors{"filled", "#ffb74d", "#c88719", "black"}; case kDarkRed: return NodeColors{"filled", "#b71c1c", "#7f0000", "white"}; case kGray: @@ -893,7 +904,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { sharding_colors_.emplace(instr->sharding(), color); return color; } - const auto kParameterColor = kOrange; + + // Choose different weights of orange for small vs large parameters. This + // distinction is often important, especially in fusion nodes. + auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange; // Special case: If this instruction has a parameter merged into it, paint it // the same color as a parameter. 
Unless the merged-in parameter is a @@ -905,7 +919,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { ShouldMergeIntoUsers(operand) && TryGetFusionParameterConstant(operand) == nullptr; })) { - return kParameterColor; + return parameter_color; } // Pick different colors or shapes for instructions which are particularly @@ -1015,7 +1029,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReducePrecision: return kRed; case HloOpcode::kParameter: - return kParameterColor; + return parameter_color; case HloOpcode::kBatchNormGrad: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormTraining: @@ -1160,20 +1174,6 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { return StrJoin(lines, "<br/>"); } -// Gets the total number of array elements in the given shape. For tuples, this -// is the sum of all the sizes of all of the array elements recursively in the -// tuple. -static int64 TotalElementsInShape(const Shape& shape) { - int64 elems = 0; - ShapeUtil::ForEachSubshape( - shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(subshape)) { - elems += ShapeUtil::ElementsIn(subshape); - } - }); - return elems; -} - void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, int64 operand_num, bool control_edge = false) { @@ -1196,14 +1196,11 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { } // We print "small" arrays using a hollow arrowhead and "large" arrays using - // a filled arrowhead. For now, we use an arbitrary cutoff for what "big" - // means. - bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; - + // a filled arrowhead. 
constexpr char kEdgeFmt[] = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to), - (is_big_array ? "normal" : "empty"), + (IsSmall(from) ? "empty" : "normal"), from->name(), to->name(), edge_label)); }; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 6d13f85cbb..f25761ac70 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -341,17 +341,21 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( source_target_pairs); break; } - case HloOpcode::kConvolution: + case HloOpcode::kConvolution: { TF_RET_CHECK(proto.operand_ids_size() == 2) << "Convolution instruction should have 2 operands but sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); + PrecisionConfigProto precision_config = proto.precision_config(); + precision_config.mutable_operand_precision()->Resize( + proto.operand_ids_size(), PrecisionConfigProto::DEFAULT); instruction = CreateConvolve( - proto.shape(), operands(0), operands(1), proto.window(), - proto.convolution_dimension_numbers(), - std::max(static_cast<int64>(proto.feature_group_count()), 1LL)); + proto.shape(), operands(0), operands(1), + std::max<int64>(proto.feature_group_count(), 1), proto.window(), + proto.convolution_dimension_numbers(), precision_config); break; + } case HloOpcode::kReduceWindow: TF_RET_CHECK(proto.operand_ids_size() == 2) << "ReduceWindow instruction should have 2 operands but sees " @@ -468,6 +472,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( computation_map.at(computation_id)); } } + if (instruction->opcode() == HloOpcode::kDot) { + instruction->precision_config_ = proto.precision_config(); + instruction->precision_config_.mutable_operand_precision()->Resize( + 
instruction->operand_count(), PrecisionConfigProto::DEFAULT); + TF_RET_CHECK(proto.has_dot_dimension_numbers()); + instruction->dot_dimension_numbers_ = + absl::make_unique<DotDimensionNumbers>( + proto.dot_dimension_numbers()); + } else { + TF_RET_CHECK(!proto.has_precision_config()) + << instruction->opcode() << proto.DebugString(); + TF_RET_CHECK(!proto.has_dot_dimension_numbers()) + << instruction->opcode(); + } break; } } @@ -476,12 +494,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); - instruction->precision_config_ = proto.precision_config(); - - if (proto.has_dot_dimension_numbers()) { - instruction->dot_dimension_numbers_ = - absl::make_unique<DotDimensionNumbers>(proto.dot_dimension_numbers()); - } if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, @@ -643,10 +655,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) { + int64 feature_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfigProto& precision_config) { return absl::make_unique<HloConvolutionInstruction>( - shape, lhs, rhs, window, dimension_numbers, feature_group_count); + shape, lhs, rhs, feature_group_count, window, dimension_numbers, + precision_config); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft( @@ -658,13 +672,15 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dimension_numbers) { + 
const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto& precision_config) { auto instruction = absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); instruction->dot_dimension_numbers_ = absl::make_unique<DotDimensionNumbers>(dimension_numbers); + instruction->set_precision_config(precision_config); return instruction; } @@ -1057,7 +1073,6 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->clear_sharding(); } derived_instruction->set_metadata(metadata_); - derived_instruction->set_precision_config(precision_config_); } bool HloInstruction::HasSideEffectNoRecurse() const { @@ -1278,7 +1293,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kDot: CHECK_EQ(new_operands.size(), 2); clone = CreateDot(shape, new_operands[0], new_operands[1], - *dot_dimension_numbers_); + *dot_dimension_numbers_, precision_config()); break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); @@ -2167,7 +2182,9 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; proto.set_backend_config(backend_config_); - *proto.mutable_precision_config() = precision_config_; + if (opcode() == HloOpcode::kConvolution || opcode() == HloOpcode::kDot) { + *proto.mutable_precision_config() = precision_config_; + } if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { proto.add_called_computation_ids(computation->unique_id()); @@ -2948,7 +2965,11 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { } string HloInstruction::PrecisionConfigToString() const { - if (precision_config_.operand_precision().empty()) { + if (absl::c_all_of( + precision_config_.operand_precision(), [](int32 precision) { + return static_cast<PrecisionConfigProto::Precision>(precision) == + PrecisionConfigProto::DEFAULT; + })) { return ""; } return StrCat( diff 
--git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index cca134e8b4..55d592ff94 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -405,9 +405,9 @@ class HloInstruction { // and window describes how the filter is applied to lhs. static std::unique_ptr<HloInstruction> CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, + int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + const PrecisionConfigProto& precision_config); // Creates an FFT op, of the type indicated by fft_type. static std::unique_ptr<HloInstruction> CreateFft( @@ -418,7 +418,8 @@ class HloInstruction { // dimensions specified in 'dimension_numbers'. static std::unique_ptr<HloInstruction> CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dimension_numbers); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto& precision_config); // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 // of the LHS with dimension 0 of the RHS with no batch dimensions. 
Both LHS diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 76b0e940a6..b4e302e832 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1122,6 +1122,13 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { } } +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { // Fused expression: // @@ -1147,8 +1154,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1188,8 +1195,8 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(s, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + s, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1239,8 +1246,8 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = 
builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + data_shape, a, b_t, dot_dnums, DefaultPrecisionConfig(2))); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); auto add_operand = builder.AddInstruction( @@ -1320,8 +1327,8 @@ TEST_F(HloInstructionTest, Stringification) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto options = HloPrintOptions().set_print_metadata(false); @@ -1485,8 +1492,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto options = HloPrintOptions().Canonical(); @@ -1527,8 +1534,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1583,8 +1590,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { DotDimensionNumbers 
dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index e46afa764f..e3683aaec9 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1628,12 +1628,13 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl( HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) + int64 feature_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfigProto& precision_config) : HloInstruction(HloOpcode::kConvolution, shape), + feature_group_count_(feature_group_count), window_(window), - convolution_dimension_numbers_(dimension_numbers), - feature_group_count_(feature_group_count) { + convolution_dimension_numbers_(dimension_numbers) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1642,6 +1643,7 @@ HloConvolutionInstruction::HloConvolutionInstruction( } AppendOperand(lhs); AppendOperand(rhs); + set_precision_config(precision_config); } string HloConvolutionInstruction::ToCategory() const { @@ -1672,7 +1674,9 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl( } extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( convolution_dimension_numbers_))); - 
extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + if (feature_group_count_ != 1) { + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + } return extra; } @@ -1697,8 +1701,8 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique<HloConvolutionInstruction>( - shape, new_operands[0], new_operands[1], window(), - convolution_dimension_numbers_, feature_group_count_); + shape, new_operands[0], new_operands[1], feature_group_count_, window(), + convolution_dimension_numbers_, precision_config()); } HloReduceWindowInstruction::HloReduceWindowInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 3230383579..1c85aa4681 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -942,9 +942,9 @@ class HloConvolutionInstruction : public HloInstruction { public: explicit HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, + int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count); + const PrecisionConfigProto& precision_config); const Window& window() const override { return window_; } void set_window(const Window& window) override { window_ = window; } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -972,12 +972,13 @@ class HloConvolutionInstruction : public HloInstruction { std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( const Shape& shape, absl::Span<HloInstruction* const> new_operands, HloCloneContext* context) const override; - Window window_; - // Describes the dimension numbers used for a convolution. - ConvolutionDimensionNumbers convolution_dimension_numbers_; // The number of feature groups. 
Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count_; + // Describes the window used for a convolution. + Window window_; + // Describes the dimension numbers used for a convolution. + ConvolutionDimensionNumbers convolution_dimension_numbers_; }; class HloReduceWindowInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index ea8e6a239a..62f01c4adb 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -530,10 +530,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, attrs["backend_config"] = {/*required=*/false, AttrTy::kString, &backend_config}; - optional<std::vector<PrecisionConfigProto::Precision>> operand_precision; - attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, - &operand_precision}; - HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -913,6 +909,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + optional<std::vector<PrecisionConfigProto::Precision>> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; @@ -923,9 +922,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!feature_group_count) { feature_group_count = 1; } + PrecisionConfigProto precision_config; + if (operand_precision) { + *precision_config.mutable_operand_precision() = { + operand_precision->begin(), operand_precision->end()}; + } else { + precision_config.mutable_operand_precision()->Resize( + operands.size(), PrecisionConfigProto::DEFAULT); + } instruction = 
builder->AddInstruction(HloInstruction::CreateConvolve( - shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums, - feature_group_count.value())); + shape, /*lhs=*/operands[0], /*rhs=*/operands[1], + feature_group_count.value(), *window, *dnums, precision_config)); break; } case HloOpcode::kFft: { @@ -1272,6 +1279,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional<std::vector<tensorflow::int64>> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; + optional<std::vector<PrecisionConfigProto::Precision>> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { @@ -1296,8 +1306,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, rhs_batch_dims->end()}; } - instruction = builder->AddInstruction( - HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); + PrecisionConfigProto precision_config; + if (operand_precision) { + *precision_config.mutable_operand_precision() = { + operand_precision->begin(), operand_precision->end()}; + } else { + precision_config.mutable_operand_precision()->Resize( + operands.size(), PrecisionConfigProto::DEFAULT); + } + + instruction = builder->AddInstruction(HloInstruction::CreateDot( + shape, operands[0], operands[1], dnum, precision_config)); break; } case HloOpcode::kGather: { @@ -1414,12 +1433,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (backend_config) { instruction->set_raw_backend_config_string(std::move(*backend_config)); } - if (operand_precision) { - PrecisionConfigProto precision_config; - *precision_config.mutable_operand_precision() = {operand_precision->begin(), - operand_precision->end()}; - instruction->set_precision_config(precision_config); - } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) 
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 759789437c..0dfc0a4d1c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -382,7 +384,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1, operand_precision={high,default} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default} } )" @@ -395,7 +397,7 @@ R"(HloModule ConvolveR2_module ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { %input = f32[1,2]{1,0} parameter(0) %filter = f32[1,1]{1,0} parameter(1) - ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1 + ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf } )" @@ -408,7 +410,7 @@ R"(HloModule ConvolveBackward_module ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { %input = f32[128,7,7,512]{0,3,2,1} parameter(0) 
%filter = f32[3,3,512,512]{3,2,1,0} parameter(1) - ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1 + ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f } )" @@ -1775,5 +1777,18 @@ TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) { ::testing::HasSubstr("Operand broadcast had no shape in HLO text")); } +TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { + const string text = + R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Convolution(op::Parameter(0), op::Parameter(1))); + auto* convolution = + Cast<HloConvolutionInstruction>(computation->root_instruction()); + EXPECT_EQ(convolution->feature_group_count(), 1); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 34cba6136f..e3f4a9852a 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -422,6 +422,13 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const { : false; } +size_t ShardingMetadata::Hash() const { + if (sharding_ != nullptr) { + return sharding_->Hash(); + } + return static_cast<size_t>(0x297814aaad196e6dULL); +} + string ShardingMetadata::ToString() const { return 
sharding_ != nullptr ? sharding_->ToString() : "{}"; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index cba5db927a..e3ae82a070 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -36,6 +36,8 @@ class ShardingMetadata : public DomainMetadata { bool Matches(const DomainMetadata& other) const override; + size_t Hash() const override; + string ToString() const override; const HloSharding* sharding() const { return sharding_.get(); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 95516dec74..069586a738 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -86,8 +86,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { const Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->window(), convolution->convolution_dimension_numbers(), - convolution->feature_group_count())); + convolution->feature_group_count(), convolution->window(), + convolution->convolution_dimension_numbers())); return CheckShape(convolution, expected); } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index a4de02a890..4a71ee909b 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -165,6 +165,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor( TF_ASSIGN_OR_RETURN( computed_array, ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(), + instr->precision_config(), FindOrDie(cache_, instr->operand(0)), FindOrDie(cache_, instr->operand(1)))); } else { @@ -1030,6 +1031,7 @@ bool CanFoldDotIntoIndexedArray( 
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config, ScalarIndexedConstantArray* lhs, ConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " " << ToString(rhs); @@ -1045,9 +1047,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( new_dim_numbers.set_lhs_contracting_dimensions( 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1)); - TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, - TakeOwnership(HloEvaluator{}.EvaluateDotOp( - new_dim_numbers, lhs->literal(), *rhs->literal()))); + TF_ASSIGN_OR_RETURN( + Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, precision_config, lhs->literal(), *rhs->literal()))); // The new source dimension is wherever the non-batch non-contracting LHS // dimension "went". @@ -1063,7 +1066,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ConstantArray* lhs, ScalarIndexedConstantArray* rhs) { + const PrecisionConfigProto& precision_config, ConstantArray* lhs, + ScalarIndexedConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " " << ToString(rhs); if (!CanFoldDotIntoIndexedArray( @@ -1079,9 +1083,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( new_dim_numbers.set_rhs_contracting_dimensions( 0, rhs->source_dim() == (rhs_rank - 1) ? 
(rhs_rank - 2) : (rhs_rank - 1)); - TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, - TakeOwnership(HloEvaluator{}.EvaluateDotOp( - new_dim_numbers, *lhs->literal(), rhs->literal()))); + TF_ASSIGN_OR_RETURN( + Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, precision_config, *lhs->literal(), rhs->literal()))); // The new source dimension is wherever the non-batch non-contracting RHS // dimension "went". @@ -1095,8 +1100,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( } StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot( - const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs, - Array* rhs) { + const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs) { // Intuitively, if // // - The LHS of a dot product is a gathered sequence of rows from a constant @@ -1119,6 +1124,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot( dynamic_cast<ScalarIndexedConstantArray*>(lhs)) { if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) { return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers, + precision_config, lhs_indexed_array, rhs_constant); } } @@ -1126,7 +1132,8 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot( if (auto* rhs_indexed_array = dynamic_cast<ScalarIndexedConstantArray*>(rhs)) { if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) { - return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant, + return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, + precision_config, lhs_constant, rhs_indexed_array); } } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index dcfb725535..f21e784a4d 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -267,15 +267,17 
@@ class IndexedArrayAnalysis { StatusOr<Array*> ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config, ScalarIndexedConstantArray* lhs, ConstantArray* rhs); StatusOr<Array*> ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ConstantArray* lhs, ScalarIndexedConstantArray* rhs); + const PrecisionConfigProto& precision_config, ConstantArray* lhs, + ScalarIndexedConstantArray* rhs); - StatusOr<Array*> ComputeArrayForDot(const Shape& shape, - const DotDimensionNumbers& dim_numbers, - Array* lhs, Array* rhs); + StatusOr<Array*> ComputeArrayForDot( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs); // This tries to fold a ScalarIndexedArray which has another // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 021fe630ff..69c7e42601 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -874,18 +874,18 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { )"; auto module = ParseHloString(module_str).ValueOrDie(); - module = + auto compiled_module = backend() .compiler() ->RunHloPasses(std::move(module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - - auto copy = FindInstruction(module.get(), "copy.1"); - auto slice = FindInstruction(module.get(), "slice0"); - EXPECT_EQ(slice->operand(0), copy); - EXPECT_TRUE( - LayoutUtil::Equal(slice->shape().layout(), copy->shape().layout())); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); + 
EXPECT_THAT(root, op::Add(op::Parameter(), + op::Slice(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy))))); } TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { @@ -902,18 +902,20 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { )"; auto module = ParseHloString(module_str).ValueOrDie(); - module = + auto compiled_module = backend() .compiler() ->RunHloPasses(std::move(module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - - auto copy = FindInstruction(module.get(), "copy.1"); - auto dslice = FindInstruction(module.get(), "dslice0"); - EXPECT_EQ(dslice->operand(0), copy); - EXPECT_TRUE( - LayoutUtil::Equal(dslice->shape().layout(), copy->shape().layout())); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); + EXPECT_THAT(root, + op::Add(op::Parameter(), + op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy)), + op::Parameter(2)))); } TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { @@ -931,18 +933,20 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { )"; auto module = ParseHloString(module_str).ValueOrDie(); - module = + auto compiled_module = backend() .compiler() ->RunHloPasses(std::move(module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - - auto copy = FindInstruction(module.get(), "copy.1"); - auto concat = FindInstruction(module.get(), "concat0"); - EXPECT_EQ(concat->operand(0), copy); - EXPECT_TRUE( - LayoutUtil::Equal(concat->shape().layout(), copy->shape().layout())); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0}); + EXPECT_THAT(root, + op::Add(op::Parameter(), + 
op::Concatenate(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy)), + op::Parameter(2)))); } TEST_F(LayoutAssignmentTest, @@ -960,15 +964,39 @@ TEST_F(LayoutAssignmentTest, )"; auto module = ParseHloString(module_str).ValueOrDie(); - module = + auto compiled_module = backend() .compiler() ->RunHloPasses(std::move(module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1))); +} + +TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { + const char* module_str = R"( + HloModule PropagatingLayoutFromResultToOperand + + ENTRY PropagatingLayoutFromResultToOperand { + par0 = f32[4,5]{1,0} parameter(0) + ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]} + } + )"; - auto copy = FindInstruction(module.get(), "copy.1"); - EXPECT_EQ(copy, nullptr); + auto module = ParseHloString(module_str).ValueOrDie(); + auto compiled_module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1}); + EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)), + op::ShapeWithLayout(shape_copy)))); } } // namespace diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 2611749862..74bdf2a2e3 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1552,8 +1552,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape( - const Shape& lhs, const Shape& rhs, const Window& window, - 
const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) { + const Shape& lhs, const Shape& rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dnums) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); @@ -1672,6 +1672,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); } + if (kernel_output_features % feature_group_count > 0) { + return InvalidArgument( + "Expected output feature dimension (value %d) to be divisible by " + "feature_group_count (value %d); " + "got <conv>(%s, %s)\n" + "Dimension numbers: {%s}.", + kernel_output_features, feature_group_count, + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); + } std::vector<int64> window_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { window_dims[i] = window.dimensions(i).size(); diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index a28345acef..96a0ee165d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -108,9 +108,9 @@ class ShapeInference { // Infers the shape produced by applying the given convolutional // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr<Shape> InferConvolveShape( - const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + const Shape& lhs, const Shape& rhs, int64 feature_group_count, + const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); // Infers the shape produced by the given FFT type on the given operand. 
static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type, diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index cc92e58ef8..864ed43118 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -419,8 +419,8 @@ TEST_F(ShapeInferenceTest, Convolve) { dim1->set_padding_high(0); dim1->set_window_dilation(1); dim1->set_base_dilation(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), @@ -464,8 +464,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dim1->set_padding_high(1); dim1->set_window_dilation(2); dim1->set_base_dilation(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), @@ -509,8 +509,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dim1->set_padding_high(1); dim1->set_window_dilation(1); dim1->set_base_dilation(2); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 
9}), @@ -547,8 +547,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dim1->set_stride(2); dim1->set_padding_low(1); dim1->set_padding_high(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("each dimension exactly once")); diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 530f40e4b2..7c1f4b5cc6 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -108,8 +108,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) { } std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot( - dot->shape(), new_lhs, new_rhs, new_dim_numbers); - new_dot->set_precision_config(dot->precision_config()); + dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config()); return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } @@ -178,8 +177,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { } auto new_conv = HloInstruction::CreateConvolve( - convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); - new_conv->set_precision_config(convolution.precision_config()); + convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(), + convolution.window(), new_dnums, convolution.precision_config()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 58f767e913..e486a00e53 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ 
b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -215,6 +215,13 @@ ENTRY entry_computation { /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); } +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + // Test that a two dimension swap of the kernel gets folded into convolution. TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { auto builder = HloComputation::Builder("entry_computation"); @@ -240,10 +247,12 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), window, dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + conv_shape.ValueOrDie(), x, transpose_y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -293,10 +302,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), window, dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + conv_shape.ValueOrDie(), x, transpose_y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); 
HloComputation* entry_computation = @@ -351,10 +362,12 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), window, dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + conv_shape.ValueOrDie(), transpose_x, y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -415,10 +428,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), window, dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + conv_shape.ValueOrDie(), transpose_x, y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index a32d1f9026..e3328203a6 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1064,8 +1064,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + 
PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfigProto::DEFAULT); auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 05f90ba9fb..53b5e933b6 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -47,6 +47,12 @@ limitations under the License. namespace xla { namespace { +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} class MultiOutputFusionTest : public HloTestBase { protected: @@ -90,8 +96,8 @@ class MultiOutputFusionTest : public HloTestBase { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2))); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -154,7 +160,7 @@ class MultiOutputFusionTest : public HloTestBase { dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape, - dot_dnums)); + dot_dnums, DefaultPrecisionConfig(2))); auto computation = 
hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 997880a018..a1001296a1 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -613,7 +613,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, Array4D<float> input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], param.base_bounds[3]); - input.FillIota(1); + input.FillRandom(0.1f, 0.1f); std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); @@ -629,7 +629,14 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, auto init_value = CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); - auto computation = param.reducer == kAdd + auto reducer = param.reducer; + if (use_bfloat16() && Product(param.window_bounds) > 128) { + // To avoid numerical issues, force the reducer to be kMax for large bf16 + // windows. + reducer = kMax; + } + + auto computation = reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); ReduceWindowWithGeneralPadding( @@ -640,8 +647,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*window_strides=*/param.strides, /*padding=*/padding); - CHECK(param.reducer == kAdd || param.reducer == kMax); - auto reduce_func = param.reducer == kAdd + CHECK(reducer == kAdd || reducer == kMax); + auto reduce_func = reducer == kAdd ? 
+[](float a, float b) { return a + b; } : +[](float a, float b) { return std::max(a, b); }; std::unique_ptr<Array4D<float>> expected = @@ -809,6 +816,22 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*pad_high=*/{1, 0, 0, 0}, /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{8, 256, 256, 3}, + /*window_bounds=*/{1, 64, 64, 1}, + /*strides=*/{1, 64, 64, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 0, 2, 1}, + /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64}, + /*window_bounds=*/{112, 112, 1, 8}, + /*strides=*/{112, 112, 1, 8}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, }; INSTANTIATE_TEST_CASE_P( @@ -930,6 +953,27 @@ struct R3ReduceWindowTestData { {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2}, /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + 
/*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, + /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, }; string R3ReduceWindowTestDataToString( @@ -956,35 +1000,42 @@ class R3ReduceWindowTest : public ReduceWindowTestBase, R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } }; -TEST_P(R3ReduceWindowTest, Add) { +TEST_P(R3ReduceWindowTest, DoIt) { XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); - CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; Array3D<float> input(param.base_bounds[0], param.base_bounds[1], - param.base_bounds[2], 1.0f); + param.base_bounds[2]); + input.FillRandom(0.1f, 0.1f); std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); + auto reducer = param.reducer; + if (use_bfloat16()) { + input_literal = LiteralUtil::ConvertF32ToBF16(*input_literal); + if (Product(param.window_bounds) > 128) { + // To avoid numerical issues, force the reducer to be kMax for large bf16 + // windows. + reducer = kMax; + } + } - XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", - &b, ¶meter); + XlaOp parameter = Parameter(&b, 0, input_literal->shape(), "input"); auto init_value = CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + + auto computation = reducer == kAdd + ? 
CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); + ReduceWindow(/*operand=*/parameter, /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, /*padding=*/param.padding); - auto expected = ReferenceUtil::ReduceWindow3DAdd( - /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); - - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), - {input_arg.get()}, DefaultErrorSpec()); + ComputeAndCompare(&b, {std::move(*input_literal)}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P( @@ -1093,7 +1144,6 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, void DoIt() { XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); - CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f); diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 66983801bf..798f499870 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -20,13 +20,7 @@ py_library( ), srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = if_not_windows([ - # TODO(aaroey): tensorrt dependency has to appear before tflite so the - # build can resolve its flatbuffers symbols within the tensorrt library. - # This is an issue with the tensorrt static library and will be fixed by - # the next tensorrt release, so fix the order here after that. 
- "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows - ]) + [ + deps = [ "//tensorflow/contrib/all_reduce", "//tensorflow/contrib/batching:batch_py", "//tensorflow/contrib/bayesflow:bayesflow_py", @@ -135,6 +129,7 @@ py_library( ]) + if_not_windows([ "//tensorflow/contrib/bigtable", # depends on bigtable "//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows + "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", ]), ) diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py index b26c52294c..29dce13999 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -21,6 +21,8 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.operators import py_builtins +from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates @@ -31,41 +33,32 @@ class BuiltinFunctionTransformer(converter.Base): TF equivalent, like `len`. """ - def _convert_builtin(self, node): + def _convert_builtin(self, f, args, as_expression): template = """ - ag__.utils.dynamic_builtin(func, args) + ag__.func(args) """ - return templates.replace(template, func=node.func, args=node.args)[0].value - - def _convert_print(self, node): - template = """ - ag__.utils.dynamic_print(args) - """ - return templates.replace(template, args=node.args)[0].value + if as_expression: + return templates.replace_as_expression( + template, func=py_builtins.overload_of(f).__name__, args=args) + else: + return templates.replace( + template, func=py_builtins.overload_of(f).__name__, args=args) def visit_Call(self, node): - self.generic_visit(node) - # TODO(mdan): This won't work if the function was hidden. 
- # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead. - if (isinstance(node.func, gast.Name) and - node.func.id in ('len', 'range', 'xrange', 'float', 'int')): - return self._convert_builtin(node) - # Print needs to be handled separately because it can be read as statement. - if isinstance(node.func, gast.Name) and node.func.id == 'print': - return self._convert_print(node) + node = self.generic_visit(node) + if anno.hasanno(node.func, 'live_val'): + live_val = anno.getanno(node.func, 'live_val') + if live_val in py_builtins.SUPPORTED_BUILTINS: + node = self._convert_builtin(live_val, node.args, as_expression=True) return node def visit_Print(self, node): - self.generic_visit(node) + node = self.generic_visit(node) args = node.values # Following is the case when calling print(a, b) if len(args) == 1 and isinstance(args[0], gast.Tuple): args = args[0].elts - template = """ - fname(args) - """ - function_call = templates.replace(template, fname='print', args=args)[0] - return self.visit(function_call) + return self._convert_builtin(print, args, as_expression=False) def transform(node, ctx): diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py index d0a0cbbeb6..3e3a04f38b 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -23,6 +23,7 @@ import six from tensorflow.contrib.autograph.converters import builtin_functions from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -34,11 +35,11 @@ class BuiltinFunctionsTest(converter_testing.TestCase): def test_fn(a): return len(a) - with self.converted(test_fn, builtin_functions, {'len': len}, - array_ops.shape) 
as result: + with self.converted(test_fn, builtin_functions, {'len': len}) as result: with self.cached_session() as sess: - ops = result.test_fn(constant_op.constant([0, 0, 0])) - self.assertEqual(sess.run(ops), 3) + p = array_ops.placeholder(dtype=dtypes.int32, shape=None) + ops = result.test_fn(p) + self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3) def test_print(self): diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index 276a387180..8b38d5d080 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -29,9 +29,9 @@ import six from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.impl import conversion +from tensorflow.contrib.autograph.operators import py_builtins from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import inspect_utils -from tensorflow.contrib.autograph.utils import builtins from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_decorator @@ -150,7 +150,7 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args, unknown_arg_value = object() # Sentinel for arguments of unknown value if inspect_utils.isbuiltin(f): - return builtins.dynamic_builtin(f, *args, **kwargs) + return py_builtins.overload_of(f)(*args, **kwargs) if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): # Regular functions diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index 332d5dab19..29759bad79 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -22,6 +22,7 @@ py_library( "__init__.py", "control_flow.py", "data_structures.py", + "py_builtins.py", "slices.py", ], srcs_version = "PY2AND3", @@ -62,6 +63,16 @@ py_test( 
) py_test( + name = "py_builtins_test", + srcs = ["py_builtins_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) + +py_test( name = "slices_test", srcs = ["slices_test.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index 392cb60bcc..c4fbc260a2 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -45,6 +45,11 @@ from tensorflow.contrib.autograph.operators.data_structures import list_stack from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts from tensorflow.contrib.autograph.operators.data_structures import new_list +from tensorflow.contrib.autograph.operators.py_builtins import float_ +from tensorflow.contrib.autograph.operators.py_builtins import int_ +from tensorflow.contrib.autograph.operators.py_builtins import len_ +from tensorflow.contrib.autograph.operators.py_builtins import print_ +from tensorflow.contrib.autograph.operators.py_builtins import range_ from tensorflow.contrib.autograph.operators.slices import get_item from tensorflow.contrib.autograph.operators.slices import GetItemOpts from tensorflow.contrib.autograph.operators.slices import set_item diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index 9909e52164..9a66a6bb60 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.utils import builtins +from tensorflow.contrib.autograph.operators import py_builtins from tensorflow.python.data.ops import 
dataset_ops from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -82,8 +82,8 @@ def _py_for_stmt(iter_, extra_test, body, init_state): def _known_len_for_stmt(iter_, extra_test, body, init_state): - """Overload of for_stmt that iterates over objects that define a length.""" - n = builtins.dynamic_len(iter_) + """Overload of for_stmt that iterates over objects that admit a length.""" + n = py_builtins.len_(iter_) def while_body(iterate_index, *state): iterate = iter_[iterate_index] diff --git a/tensorflow/contrib/autograph/operators/py_builtins.py b/tensorflow/contrib/autograph/operators/py_builtins.py new file mode 100644 index 0000000000..c5730934e7 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/py_builtins.py @@ -0,0 +1,225 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators corresponding to Python builtin functions. 
+ +List of built-in functions: https://docs.python.org/3/library/functions.html +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.contrib.autograph.utils import tensors +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.ops import gen_string_ops +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import math_ops + + +UNDEFINED = object() + + +def overload_of(f): + if f in SUPPORTED_BUILTINS: + return BUILTIN_FUINCTIONS_MAP[f.__name__] + return f + + +def abs_(x): + if tensor_util.is_tensor(x): + return _tf_abs(x) + return _py_abs(x) + + +def _tf_abs(x): + return math_ops.abs(x) + + +def _py_abs(x): + return abs(x) + + +def float_(x=0): + if tensor_util.is_tensor(x): + return _tf_float(x) + return _py_float(x) + + +def _tf_float(x): + # TODO(mdan): We shouldn't assume float32. + if x.dtype == dtypes.string: + return gen_parsing_ops.string_to_number(x, out_type=dtypes.float32) + return math_ops.cast(x, dtype=dtypes.float32) + + +def _py_float(x): + return float(x) + + +def int_(x=0, base=UNDEFINED): + if tensor_util.is_tensor(x): + return _tf_int(x, base) + return _py_int(x, base) + + +def _tf_int(x, base): + if base not in (10, UNDEFINED): + raise NotImplementedError('base {} not supported for int'.format(base)) + + # TODO(mdan): We shouldn't assume int32. 
+ if x.dtype == dtypes.string: + return gen_parsing_ops.string_to_number(x, out_type=dtypes.int32) + return math_ops.cast(x, dtype=dtypes.int32) + + +def _py_int(x, base): + if base is UNDEFINED: + return int(x) + return int(x, base) + + +def len_(s): + if tensors.is_tensor_array(s): + return _tf_tensor_array_len(s) + elif tensors.is_tensor_list(s): + return _tf_tensor_list_len(s) + elif tensor_util.is_tensor(s): + return _tf_tensor_len(s) + return _py_len(s) + + +def _tf_tensor_array_len(s): + return s.size() + + +def _tf_tensor_list_len(s): + return list_ops.tensor_list_length(s) + + +def _tf_tensor_len(s): + """Overload of len_ for Tensor arguments.""" + # Statically shaped tensors: length is known ahead of time. + if s.shape.ndims and s.shape[0].value is not None: + return s.shape[0].value + + # Static shape of unknown dimensions: use dynamic shape but statically + # chech that it's a scalar. + shape = array_ops.shape(s) + + assert shape.shape, 'shape tensor of zero size? {}'.format(shape) + + if shape.shape[0] == 0: + raise ValueError( + 'len requires a non-scalar tensor, got one of shape {}'.format(shape)) + + if shape.shape[0].value is not None: + return array_ops.shape(s)[0] + + # Fully dynamic shape: use ops. + rank = array_ops.rank(s) + + def raise_zero_rank_error(): + msg = gen_string_ops.string_join( + ['len requires non-zero rank, got ', + gen_string_ops.as_string(rank)]) + with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]): + return constant_op.constant(0, dtype=dtypes.int32) + + return control_flow_ops.cond(rank > 0, lambda: array_ops.shape(s)[0], + raise_zero_rank_error) + + +def _py_len(s): + return len(s) + + +def print_(*objects, **kwargs): + # Note: Python 2.6 doesn't support explicit keywords after starargs. 
+ unknown_kwargs = tuple( + set(kwargs.keys()) - set(('sep', 'end', 'file', 'flush'))) + if unknown_kwargs: + raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs)) + + # TODO(mdan): use logging_ops.Print when py_func is not supported. + return _tf_py_func_print(objects, kwargs) + + +def _tf_py_func_print(objects, kwargs): + """Overload of print_ as a py_func implementation.""" + override_kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED} + if 'flush' not in override_kwargs: + # Defaulting to flushing the console in graph mode, which helps reduce + # garbled output in IPython. + override_kwargs['flush'] = True + + def print_wrapper(*vals): + if six.PY3: + # TensorFlow doesn't seem to generate Unicode when passing strings to + # py_func. This causes the print to add a "b'" wrapper to the output, + # which is probably never what you want. + vals = tuple( + v.decode('utf-8') if isinstance(v, bytes) else v for v in vals) + six.print_(*vals, **override_kwargs) + + return py_func.wrap_py_func( + print_wrapper, None, objects, use_dummy_return=True) + + +def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED): + if any(tensor_util.is_tensor(s) for s in (start_or_stop, stop, step)): + return _tf_range(start_or_stop, stop, step) + return _py_range(start_or_stop, stop, step) + + +def _tf_range(start_or_stop, stop, step): + # TODO(mdan): We should optimize this when a full tensor is not required. 
+ if step is not UNDEFINED: + return math_ops.range(start_or_stop, stop, step) + if stop is not UNDEFINED: + return math_ops.range(start_or_stop, stop) + return math_ops.range(start_or_stop) + + +def _py_range(start_or_stop, stop, step): + if step is not UNDEFINED: + return range(start_or_stop, stop, step) + if stop is not UNDEFINED: + return range(start_or_stop, stop) + return range(start_or_stop) + + +SUPPORTED_BUILTINS = set((abs, float, int, len, print, range)) + +if six.PY2: + SUPPORTED_BUILTINS.add(xrange) + +BUILTIN_FUINCTIONS_MAP = { + 'abs': abs_, + 'float': float_, + 'int': int_, + 'len': len_, + 'print': print_, + 'range': range_, + 'xrange': range_, +} diff --git a/tensorflow/contrib/autograph/operators/py_builtins_test.py b/tensorflow/contrib/autograph/operators/py_builtins_test.py new file mode 100644 index 0000000000..4073c51785 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/py_builtins_test.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for py_builtins module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import six + +from tensorflow.contrib.autograph.operators import data_structures +from tensorflow.contrib.autograph.operators import py_builtins +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test + + +class PyBuiltinsTest(test.TestCase): + + def test_abs(self): + self.assertEqual(py_builtins.abs_(-1), 1) + with self.test_session() as sess: + t = py_builtins.abs_(constant_op.constant(-1)) + self.assertEqual(sess.run(t), 1) + t = py_builtins.abs_(constant_op.constant([-1, 2, -3])) + self.assertAllEqual(sess.run(t), [1, 2, 3]) + + def test_float(self): + self.assertEqual(py_builtins.float_(10), 10.0) + self.assertEqual(py_builtins.float_('10.0'), 10.0) + with self.test_session() as sess: + t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64)) + self.assertEqual(sess.run(t), 1.0) + st = py_builtins.float_(constant_op.constant('1.0')) + self.assertEqual(sess.run(st), 1.0) + + def test_int(self): + self.assertEqual(py_builtins.int_(10.0), 10) + self.assertEqual(py_builtins.int_('11', 2), 3) + with self.test_session() as sess: + t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64)) + self.assertEqual(sess.run(t), 1) + st = py_builtins.int_(constant_op.constant('1')) + self.assertEqual(sess.run(st), 1) + st = py_builtins.int_(constant_op.constant('1'), 10) + self.assertEqual(sess.run(st), 1) + + def test_int_unsupported_base(self): + t = constant_op.constant(1, dtype=dtypes.float64) + with self.assertRaises(NotImplementedError): + py_builtins.int_(t, 2) + + 
def test_len(self): + self.assertEqual(py_builtins.len_([1, 2, 3]), 3) + with self.test_session() as sess: + t = py_builtins.len_(constant_op.constant([[1], [2], [3]])) + self.assertEqual(t, 3) + ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5)) + self.assertEqual(sess.run(ta), 5) + tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5])) + self.assertEqual(sess.run(tl), 3) + + def test_len_scalar(self): + with self.assertRaises(ValueError): + py_builtins.len_(constant_op.constant(1)) + + def test_len_dynamic_shape(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtype=dtypes.int32, shape=None) + t = py_builtins.len_(p) + self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3) + + with self.assertRaises(errors_impl.InvalidArgumentError): + t = py_builtins.len_(p) + sess.run(t, {p: 1}) + + def test_print_tensors(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(py_builtins.print_(constant_op.constant('test message'), 1)) + self.assertEqual(out_capturer.getvalue(), 'test message 1\n') + finally: + sys.stdout = sys.__stdout__ + + def test_print_complex(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run( + py_builtins.print_(constant_op.constant('test message'), [1, 2])) + self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') + finally: + sys.stdout = sys.__stdout__ + + def test_range(self): + self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2]) + self.assertListEqual(list(py_builtins.range_(1, 3)), [1, 2]) + self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1]) + + def test_range_tensor(self): + with self.test_session() as sess: + r = py_builtins.range_(constant_op.constant(3)) + self.assertAllEqual(sess.run(r), [0, 1, 2]) + r = py_builtins.range_(1, constant_op.constant(3)) + self.assertAllEqual(sess.run(r), [1, 2]) + r = py_builtins.range_(2, 0, 
constant_op.constant(-1)) + self.assertAllEqual(sess.run(r), [2, 1]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD index d2b399f19b..4504a5c7a3 100644 --- a/tensorflow/contrib/autograph/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -20,12 +20,12 @@ py_library( name = "utils", srcs = [ "__init__.py", - "builtins.py", "context_managers.py", "misc.py", "multiple_dispatch.py", "py_func.py", "tensor_list.py", + "tensors.py", "testing.py", "type_check.py", ], @@ -42,17 +42,6 @@ py_library( ) py_test( - name = "builtins_test", - srcs = ["builtins_test.py"], - srcs_version = "PY2AND3", - tags = ["no_windows"], - deps = [ - ":utils", - "//tensorflow/python:client_testlib", - ], -) - -py_test( name = "context_managers_test", srcs = ["context_managers_test.py"], srcs_version = "PY2AND3", @@ -113,3 +102,13 @@ py_test( "//tensorflow/python:list_ops", ], ) + +py_test( + name = "tensors_test", + srcs = ["tensors_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py index 57b5f74741..38e0a0a8f0 100644 --- a/tensorflow/contrib/autograph/utils/__init__.py +++ b/tensorflow/contrib/autograph/utils/__init__.py @@ -18,9 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin -from tensorflow.contrib.autograph.utils.builtins import dynamic_print -from tensorflow.contrib.autograph.utils.builtins import dynamic_range from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns from tensorflow.contrib.autograph.utils.misc import alias_tensors from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is diff --git 
a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py deleted file mode 100644 index 4dd440ef19..0000000000 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Builtin conversion utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - -import six - -from tensorflow.contrib.autograph.utils import py_func -from tensorflow.contrib.autograph.utils import type_check -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import list_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import math_ops - - -def dynamic_builtin(f, *args, **kwargs): - """Converts a builtin function call inline.""" - if f is len: - return dynamic_len(*args, **kwargs) - if six.PY2 and f is xrange: - return dynamic_range(*args, **kwargs) - if f is range: - return dynamic_range(*args, **kwargs) - if f is int: - return dynamic_int(*args, **kwargs) - if f is float: - return dynamic_float(*args, **kwargs) - if f is abs: - return dynamic_abs(*args, **kwargs) - - raise NotImplementedError( - 'The "%s" 
builtin is not yet supported.' % f.__name__) - - -def dynamic_len(list_or_tensor): - """Implementation of len using dynamic dispatch.""" - if _is_tensor_list(list_or_tensor): - return list_ops.tensor_list_length(list_or_tensor) - elif tensor_util.is_tensor(list_or_tensor): - shape = list_or_tensor.shape - if not shape.ndims: - raise ValueError( - 'len requires non-zero rank for tensor "%s"' % list_or_tensor) - return array_ops.shape(list_or_tensor)[0] - return len(list_or_tensor) - - -def _is_tensor_list(list_or_tensor): - return (tensor_util.is_tensor(list_or_tensor) - and list_or_tensor.dtype == dtypes.variant) - - -def dynamic_int(num_or_tensor, **kwargs): - """Implementation of int() using dynamic dispatch.""" - if tensor_util.is_tensor(num_or_tensor): - return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs) - return int(num_or_tensor) - - -def dynamic_float(num_or_tensor, **kwargs): - """Implementation of float() using dynamic dispatch.""" - if tensor_util.is_tensor(num_or_tensor): - return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs) - return float(num_or_tensor) - - -def dynamic_abs(num_or_tensor, **kwargs): - if tensor_util.is_tensor(num_or_tensor): - return math_ops.abs(num_or_tensor, **kwargs) - else: - return abs(num_or_tensor, **kwargs) - - -def dynamic_range(start_or_stop, stop=None, step=None): - """Implementation of range using dynamic dispatch.""" - if type_check.is_tensor(start_or_stop, stop, step): - if step is not None: - return math_ops.range(start_or_stop, stop, step) - if stop is not None: - return math_ops.range(start_or_stop, stop) - return math_ops.range(start_or_stop) - - if step is not None: - return range(start_or_stop, stop, step) - elif stop is not None: - return range(start_or_stop, stop) - return range(start_or_stop) - - -def is_tf_print_compatible(value): - # TODO(mdan): Enable once we can reliably test this. - # This is currently disabled because we can't capture the output of - # op kernels from Python. 
- del value - return False - - -def dynamic_print(*values): - """Implementation of print using dynamic dispatch. - - The function attempts to use tf.Print if all the values are compatible. - Otherwise, it will fall back to py_func. - - Args: - *values: values to print - Returns: - A dummy value indicating the print completed. If tf. - """ - - if all(map(is_tf_print_compatible, values)): - return logging_ops.Print(1, values) - - def print_wrapper(*vals): - if six.PY3: - # TensorFlow doesn't seem to generate Unicode when passing strings to - # py_func. This causes the print to add a "b'" wrapper to the output, - # which is probably never what you want. - vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals) - print(*vals) - # The flush helps avoid garbled output in IPython. - sys.stdout.flush() - - return py_func.wrap_py_func( - print_wrapper, None, values, use_dummy_return=True) diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py deleted file mode 100644 index b1cd5253bc..0000000000 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for builtins module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - -import six - -from tensorflow.contrib.autograph.utils import builtins -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.platform import test - - -class BuiltinsTest(test.TestCase): - - def test_dynamic_len_tf_scalar(self): - a = constant_op.constant(1) - - with self.assertRaisesRegexp(ValueError, - 'len requires non-zero rank for tensor.*'): - with self.test_session() as sess: - sess.run(builtins.dynamic_builtin(len, a)) - - def test_dynamic_len_tf_array(self): - a = constant_op.constant([1, 2, 3]) - - with self.test_session() as sess: - self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a))) - - def test_dynamic_abs_tf_scalar(self): - a = constant_op.constant(-1) - - with self.test_session() as sess: - self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a))) - - def test_dynamic_abs_tf_array(self): - a = constant_op.constant([-1, 2, -3]) - - with self.test_session() as sess: - self.assertListEqual([1, 2, 3], - list(sess.run(builtins.dynamic_builtin(abs, a)))) - - def test_dynamic_abs_py_scalar(self): - a = -1 - self.assertEqual(1, builtins.dynamic_builtin(abs, a)) - - def test_dynamic_len_tf_matrix(self): - a = constant_op.constant([[1, 2], [3, 4]]) - - with self.test_session() as sess: - self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a))) - - def test_dynamic_len_py_list(self): - a = [3] * 5 - - self.assertEqual(5, builtins.dynamic_builtin(len, a)) - - def test_dynamic_range_all_python(self): - self.assertListEqual(list(builtins.dynamic_builtin(range, 3)), [0, 1, 2]) - self.assertListEqual(list(builtins.dynamic_builtin(range, 1, 3)), [1, 2]) - self.assertListEqual( - list(builtins.dynamic_builtin(range, 2, 0, -1)), [2, 
1]) - - def test_dynamic_range_tf(self): - with self.test_session() as sess: - self.assertAllEqual( - sess.run(builtins.dynamic_builtin(range, constant_op.constant(3))), - [0, 1, 2]) - self.assertAllEqual( - sess.run(builtins.dynamic_builtin(range, 1, constant_op.constant(3))), - [1, 2]) - self.assertAllEqual( - sess.run( - builtins.dynamic_builtin(range, 2, 0, constant_op.constant(-1))), - [2, 1]) - - def test_dynamic_range_detection(self): - def range(x): # pylint:disable=redefined-builtin - return x - - # Functions that just have the names of builtins are rejected. - with self.assertRaises(NotImplementedError): - self.assertEqual(builtins.dynamic_builtin(range, 1), 1) - if six.PY2: - self.assertListEqual( - list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2]) - self.assertListEqual( - list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2]) - self.assertListEqual( - list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2]) - - def test_casts(self): - i = constant_op.constant(2, dtype=dtypes.int32) - f = constant_op.constant(1.0, dtype=dtypes.float32) - - self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32) - self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32) - self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32) - self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32) - - self.assertEqual(builtins.dynamic_builtin(int, True), 1) - self.assertEqual(builtins.dynamic_builtin(int, False), 0) - self.assertEqual(builtins.dynamic_builtin(float, True), 1.0) - self.assertEqual(builtins.dynamic_builtin(float, False), 0.0) - - def test_dynamic_print_tf(self): - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - with self.test_session() as sess: - sess.run(builtins.dynamic_print('test message', 1)) - self.assertEqual(out_capturer.getvalue(), 'test message 1\n') - finally: - sys.stdout = sys.__stdout__ - - def test_dynamic_print_complex(self): - try: - out_capturer = 
six.StringIO() - sys.stdout = out_capturer - with self.test_session() as sess: - sess.run(builtins.dynamic_print('test message', [1, 2])) - self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') - finally: - sys.stdout = sys.__stdout__ - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/autograph/utils/tensors.py b/tensorflow/contrib/autograph/utils/tensors.py new file mode 100644 index 0000000000..fa5db81a71 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/tensors.py @@ -0,0 +1,41 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This module defines tensor utilities not found in TensorFlow. + +The reason these utilities are not defined in TensorFlow is because they may +not be not fully robust, although they work in the vast majority of cases. So +we define them here in order for their behavior to be consistently verified. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import tensor_array_ops + + +def is_tensor_array(t): + return isinstance(t, tensor_array_ops.TensorArray) + + +def is_tensor_list(t): + # TODO(mdan): This is just a heuristic. 
+ # With TF lacking support for templated types, this is unfortunately the + # closest we can get right now. A dedicated op ought to be possible to + # construct. + return (tensor_util.is_tensor(t) and t.dtype == dtypes.variant and + not t.shape.ndims) diff --git a/tensorflow/contrib/autograph/utils/tensors_test.py b/tensorflow/contrib/autograph/utils/tensors_test.py new file mode 100644 index 0000000000..e855e0b6cb --- /dev/null +++ b/tensorflow/contrib/autograph/utils/tensors_test.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensors module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils import tensors +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test + + +class TensorsTest(test.TestCase): + + def _simple_tensor_array(self): + return tensor_array_ops.TensorArray(dtypes.int32, size=3) + + def _simple_tensor_list(self): + return list_ops.empty_tensor_list( + element_shape=constant_op.constant([1]), element_dtype=dtypes.int32) + + def _simple_list_of_tensors(self): + return [constant_op.constant(1), constant_op.constant(2)] + + def test_is_tensor_array(self): + self.assertTrue(tensors.is_tensor_array(self._simple_tensor_array())) + self.assertFalse(tensors.is_tensor_array(self._simple_tensor_list())) + self.assertFalse(tensors.is_tensor_array(constant_op.constant(1))) + self.assertFalse(tensors.is_tensor_array(self._simple_list_of_tensors())) + self.assertFalse(tensors.is_tensor_array(None)) + + def test_is_tensor_list(self): + self.assertFalse(tensors.is_tensor_list(self._simple_tensor_array())) + self.assertTrue(tensors.is_tensor_list(self._simple_tensor_list())) + self.assertFalse(tensors.is_tensor_list(constant_op.constant(1))) + self.assertFalse(tensors.is_tensor_list(self._simple_list_of_tensors())) + self.assertFalse(tensors.is_tensor_list(None)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index e6407174b1..35d727482b 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ 
b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -141,11 +141,18 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): # The bias is computed on gradients and hessians (and not # filtered_gradients) which have exactly one value per example, so we # don't double count a gradient in multivalent columns. + # Since unsorted_segment_sum can be numerically unstable, use 64bit + # operation. + gradients64 = math_ops.cast(gradients, dtypes.float64) + hessians64 = math_ops.cast(hessians, dtypes.float64) per_partition_gradients = math_ops.unsorted_segment_sum( - gradients, mapped_partitions, array_ops.size(unique_partitions)) + gradients64, mapped_partitions, array_ops.size(unique_partitions)) per_partition_hessians = math_ops.unsorted_segment_sum( - hessians, mapped_partitions, array_ops.size(unique_partitions)) - + hessians64, mapped_partitions, array_ops.size(unique_partitions)) + per_partition_gradients = math_ops.cast(per_partition_gradients, + dtypes.float32) + per_partition_hessians = math_ops.cast(per_partition_hessians, + dtypes.float32) # Prepend a bias feature per partition that accumulates the stats for all # examples in that partition. # Bias is added to the stats even if there are no examples with values in diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 38c0a09c33..92d4251a86 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -220,6 +220,7 @@ def sample_from_datasets(datasets, weights=None, seed=None): if weights is None: # Select inputs with uniform probability. logits = [[1.0] * num_datasets] + else: # Use the given `weights` as the probability of choosing the respective # input. 
@@ -245,8 +246,11 @@ def sample_from_datasets(datasets, weights=None, seed=None): return array_ops.squeeze( stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) - selector_input = random_ops.RandomDataset(seed).batch(2).map( - select_dataset_constant_logits) + selector_input = dataset_ops.MapDataset( + random_ops.RandomDataset(seed).batch(2), + select_dataset_constant_logits, + use_inter_op_parallelism=False) + else: # Use each element of the given `weights` dataset as the probability of # choosing the respective input. @@ -259,9 +263,12 @@ def sample_from_datasets(datasets, weights=None, seed=None): return array_ops.squeeze( stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) - selector_input = dataset_ops.Dataset.zip( - (logits_ds, random_ops.RandomDataset(seed).batch(2) - )).map(select_dataset_varying_logits) + logits_and_seeds = dataset_ops.Dataset.zip( + (logits_ds, random_ops.RandomDataset(seed).batch(2))) + selector_input = dataset_ops.MapDataset( + logits_and_seeds, + select_dataset_varying_logits, + use_inter_op_parallelism=False) return _DirectedInterleaveDataset(selector_input, datasets) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 7f09ba71dc..4c466781f7 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -499,7 +499,8 @@ def make_csv_dataset( # indefinitely, and all batches will be full-sized. dataset = dataset.batch(batch_size=batch_size, drop_remainder=num_epochs is None) - dataset = dataset.map(map_fn) + dataset = dataset_ops.MapDataset( + dataset, map_fn, use_inter_op_parallelism=False) dataset = dataset.prefetch(prefetch_buffer_size) return dataset @@ -778,7 +779,8 @@ def make_batched_features_dataset(file_pattern, # Extract values if the `Example` tensors are stored as key-value tuples. 
if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset.map(lambda _, v: v) + dataset = dataset_ops.MapDataset( + dataset, lambda _, v: v, use_inter_op_parallelism=False) # Apply dataset repeat and shuffle transformations. dataset = _maybe_shuffle_and_repeat( diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 4fa8aa06cc..77079d0df9 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -229,6 +229,8 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): if not session_config or not self._cluster_spec: return + session_config.isolate_session_state = True + assert self._task_type assert self._task_id is not None diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index d1235b7afb..0c6805d682 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -572,6 +572,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): task_type=None, task_id=None): del task_type, task_id + + if session_config: + session_config.isolate_session_state = True + if cluster_spec: self._initialize_multi_worker(self._num_gpus, cluster_spec) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 88d7768b14..1125d027f6 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -412,6 +412,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): if not session_config or not self._cluster_spec: return + session_config.isolate_session_state = False + assert 
self._cluster_spec assert self._task_type assert self._task_id is not None diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 32d7444e42..4fb70ec685 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -311,3 +311,16 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): if self._tpu_cluster_resolver.get_master() in ('', 'local'): return '/replica:0/task:0/device:CPU:0' return '/job:tpu_worker/task:%d/device:CPU:0' % (host_id,) + + def configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): + del cluster_spec, task_type, task_id + if session_config: + session_config.isolate_session_state = True + cluster_spec = self._tpu_cluster_resolver.cluster_spec() + if cluster_spec: + session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 315d7a4893..529c99b37c 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -66,7 +66,7 @@ "\n", "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", "\n", - "Our goal is generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", + "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. 
This enables us to see which parts of the image the model focuses on as it generates a caption.\n", "\n", "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", "\n", @@ -128,7 +128,7 @@ "source": [ "## Download and prepare the MS-COCO dataset\n", "\n", - "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code code below will download and extract the dataset automatically. \n", + "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", "\n", "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." ] diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index aa99616810..dcc7b71d79 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -25,11 +25,14 @@ from tensorflow.contrib.eager.python import metrics from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.training import training_util from tensorflow.python.training.checkpointable import util as checkpointable_utils @@ -244,6 +247,48 @@ class 
MetricsTest(test.TestCase): value = m.value() self.assertEqual(self.evaluate(value), 2.5) + @test_util.run_in_graph_and_eager_modes + def testGraphAndEagerTensorGlobalVariables(self): + m = metrics.Mean(use_global_variables=True) + inputs = ops.convert_to_tensor([1.0, 2.0]) + accumulate = m(inputs) + result = m.result() + self.evaluate(m.init_variables()) + self.evaluate(accumulate) + self.assertEqual(self.evaluate(result), 1.5) + # Second init resets all the variables. + self.evaluate(m.init_variables()) + inputs = ops.convert_to_tensor([2.0, 3.0]) + self.evaluate(m(inputs)) + value = m.value() + self.assertEqual(self.evaluate(value), 2.5) + + @test_util.run_in_graph_and_eager_modes + def testGraphAndEagerTensorWhileLoopDoubleCall(self): + m = metrics.Mean() + init_value = constant_op.constant(1) + cond = lambda i: math_ops.less(i, 3) + def body(x): + with ops.control_dependencies([m(x)]): + return math_ops.add(x, 1) + accumulate = control_flow_ops.while_loop(cond, body, [init_value]) + + result = m.result() + self.evaluate(m.init_variables()) + self.evaluate(accumulate) + self.assertEqual(self.evaluate(result), 1.5) + # Second init resets all the variables. + self.evaluate(m.init_variables()) + inputs = ops.convert_to_tensor([2.0, 3.0]) + self.evaluate(m(inputs)) + if ops.context.executing_eagerly(): + self.evaluate(control_flow_ops.while_loop(cond, body, [init_value])) + else: + # Reuse the loop operators in graph mode + self.evaluate(accumulate) + value = m.value() + self.assertEqual(self.evaluate(value), 2.0) + def testTwoMeansGraph(self): # Verify two metrics with the same name in the same graph raises a # ValueError. 
diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index ca46c39baa..b82bf1188f 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -377,64 +377,68 @@ class WALSMatrixFactorization(estimator.Estimator): WALS (Weighted Alternating Least Squares) is an algorithm for weighted matrix factorization. It computes a low-rank approximation of a given sparse (n x m) - matrix A, by a product of two matrices, U * V^T, where U is a (n x k) matrix - and V is a (m x k) matrix. Here k is the rank of the approximation, also - called the embedding dimension. We refer to U as the row factors, and V as the - column factors. + matrix `A`, by a product of two matrices, `U * V^T`, where `U` is a (n x k) + matrix and `V` is a (m x k) matrix. Here k is the rank of the approximation, + also called the embedding dimension. We refer to `U` as the row factors, and + `V` as the column factors. See tensorflow/contrib/factorization/g3doc/wals.md for the precise problem formulation. - The training proceeds in sweeps: during a row_sweep, we fix V and solve for U. - During a column sweep, we fix U and solve for V. Each one of these problems is - an unconstrained quadratic minimization problem and can be solved exactly (it - can also be solved in mini-batches, since the solution decouples nicely). + The training proceeds in sweeps: during a row_sweep, we fix `V` and solve for + `U`. During a column sweep, we fix `U` and solve for `V`. Each one of these + problems is an unconstrained quadratic minimization problem and can be solved + exactly (it can also be solved in mini-batches, since the solution decouples + across rows of each matrix). The alternating between sweeps is achieved by using a hook during training, which is responsible for keeping track of the sweeps and running preparation ops at the beginning of each sweep. 
It also updates the global_step variable, which keeps track of the number of batches processed since the beginning of training. The current implementation assumes that the training is run on a single - machine, and will fail if config.num_worker_replicas is not equal to one. - Training is done by calling self.fit(input_fn=input_fn), where input_fn + machine, and will fail if `config.num_worker_replicas` is not equal to one. + Training is done by calling `self.fit(input_fn=input_fn)`, where `input_fn` provides two tensors: one for rows of the input matrix, and one for rows of the transposed input matrix (i.e. columns of the original matrix). Note that during a row sweep, only row batches are processed (ignoring column batches) and vice-versa. Also note that every row (respectively every column) of the input matrix must be processed at least once for the sweep to be considered complete. In - particular, training will not make progress if input_fn does not generate some - rows. - - For prediction, given a new set of input rows A' (e.g. new rows of the A - matrix), we compute a corresponding set of row factors U', such that U' * V^T - is a good approximation of A'. We call this operation a row projection. A - similar operation is defined for columns. - Projection is done by calling self.get_projections(input_fn=input_fn), where - input_fn satisfies the constraints given below. - - The input functions must satisfy the following constraints: Calling input_fn - must return a tuple (features, labels) where labels is None, and features is - a dict containing the following keys: + particular, training will not make progress if some rows are not generated by + the `input_fn`. + + For prediction, given a new set of input rows `A'`, we compute a corresponding + set of row factors `U'`, such that `U' * V^T` is a good approximation of `A'`. + We call this operation a row projection. A similar operation is defined for + columns. 
Projection is done by calling + `self.get_projections(input_fn=input_fn)`, where `input_fn` satisfies the + constraints given below. + + The input functions must satisfy the following constraints: Calling `input_fn` + must return a tuple `(features, labels)` where `labels` is None, and + `features` is a dict containing the following keys: + TRAIN: - - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix). + * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix). Rows of the input matrix to process (or to project). - - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix). + * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix). Columns of the input matrix to process (or to project), transposed. + INFER: - - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix). + * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix). Rows to project. - - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix). + * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix). Columns to project. - - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project + * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project the rows or columns. - - WALSMatrixFactorization.PROJECTION_WEIGHTS (Optional): float32 Tensor + * `WALSMatrixFactorization.PROJECTION_WEIGHTS` (Optional): float32 Tensor (vector). The weights to use in the projection. + EVAL: - - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix). + * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix). Rows to project. - - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix). + * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix). Columns to project. - - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project + * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project the rows or columns. 
""" # Keys to be used in model_fn @@ -469,7 +473,7 @@ class WALSMatrixFactorization(estimator.Estimator): max_sweeps=None, model_dir=None, config=None): - """Creates a model for matrix factorization using the WALS method. + r"""Creates a model for matrix factorization using the WALS method. Args: num_rows: Total number of rows for input matrix. diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 36b483c6d7..31820a18b4 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -125,11 +125,13 @@ class WALSMatrixFactorizationTest(test.TestCase): nz_row_ids = np.arange(np.shape(np_matrix)[0]) nz_col_ids = np.arange(np.shape(np_matrix)[1]) - def extract_features(row_batch, col_batch, shape): + def extract_features(row_batch, col_batch, num_rows, num_cols): row_ids = row_batch[0] col_ids = col_batch[0] - rows = self.remap_sparse_tensor_rows(row_batch[1], row_ids, shape) - cols = self.remap_sparse_tensor_rows(col_batch[1], col_ids, shape) + rows = self.remap_sparse_tensor_rows( + row_batch[1], row_ids, shape=[num_rows, num_cols]) + cols = self.remap_sparse_tensor_rows( + col_batch[1], col_ids, shape=[num_cols, num_rows]) features = { wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows, wals_lib.WALSMatrixFactorization.INPUT_COLS: cols, @@ -154,7 +156,7 @@ class WALSMatrixFactorizationTest(test.TestCase): capacity=10, enqueue_many=True) - features = extract_features(row_batch, col_batch, sp_mat.dense_shape) + features = extract_features(row_batch, col_batch, num_rows, num_cols) if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL: self.assertTrue( diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index ab9886580d..7243f150ce 100644 --- 
a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -184,7 +184,7 @@ class GANEstimator(estimator.Estimator): return _get_estimator_spec( mode, gan_model, generator_loss_fn, discriminator_loss_fn, get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - get_hooks_fn) + get_hooks_fn, use_loss_summaries) super(GANEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) @@ -211,15 +211,17 @@ def _get_gan_model( def _get_estimator_spec( mode, gan_model, generator_loss_fn, discriminator_loss_fn, get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - get_hooks_fn=None): + get_hooks_fn=None, use_loss_summaries=True): """Get the EstimatorSpec for the current mode.""" if mode == model_fn_lib.ModeKeys.PREDICT: estimator_spec = model_fn_lib.EstimatorSpec( mode=mode, predictions=gan_model.generated_data) else: gan_loss = tfgan_tuples.GANLoss( - generator_loss=generator_loss_fn(gan_model), - discriminator_loss=discriminator_loss_fn(gan_model)) + generator_loss=generator_loss_fn( + gan_model, add_summaries=use_loss_summaries), + discriminator_loss=discriminator_loss_fn( + gan_model, add_summaries=use_loss_summaries)) if mode == model_fn_lib.ModeKeys.EVAL: estimator_spec = _get_eval_estimator_spec( gan_model, gan_loss, get_eval_metric_ops_fn) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 9ac9c6ca9c..83f8dd641f 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -116,7 +116,7 @@ def get_dummy_gan_model(): discriminator_fn=None) -def dummy_loss_fn(gan_model): +def dummy_loss_fn(gan_model, add_summaries=True): return math_ops.reduce_sum(gan_model.discriminator_real_outputs - gan_model.discriminator_gen_outputs) 
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index eee90864b4..52c9c4f3be 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1288,7 +1288,7 @@ class ConvolutionInPlaneTest(test.TestCase): result = sess.run(vert_gradients) expected = np.zeros((1, 9, 10, 1)) - self.assertAllEqual(result, expected) + self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5) def testVertConvWithVaryingImage(self): image = np.asmatrix(('1.0 2.0 3.0;' '1.1 2.0 4.0;' '-4.3 0.0 8.9')) diff --git a/tensorflow/contrib/lite/RELEASE.md b/tensorflow/contrib/lite/RELEASE.md deleted file mode 100644 index 8fd63d5cee..0000000000 --- a/tensorflow/contrib/lite/RELEASE.md +++ /dev/null @@ -1,8 +0,0 @@ -# Release 0.1.7 - -* TensorFlow Lite 0.1.7 is based on tag `tflite-v0.1.7` (git commit - fa1db5eb0da85b5baccc2a46d534fdeb3bb473d0). -* To reproduce the iOS library, it's required to cherry pick git commit - f1f1d5172fe5bfeaeb2cf657ffc43ba744187bee to fix a dependency issue. -* The code is based on TensorFlow 1.8.0 release candidate and it's very close - to TensorFlow 1.8.0 release. diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index fc199f0a0e..0246e7fa30 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -57,6 +57,7 @@ def tflite_linkopts_unstripped(): "-Wl,--as-needed", # Don't link unused libs. 
], "//tensorflow:darwin": [], + "//tensorflow:ios": [], "//tensorflow/contrib/lite:mips": [], "//tensorflow/contrib/lite:mips64": [], "//conditions:default": [ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h index c658e43092..7c5099235a 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h @@ -257,6 +257,16 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( } else { max_coeff = raw_input.maxCoeff(); } + + // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))). + float logsumexp = 0.0; + for (int j = 0; j < raw_input.size(); ++j) { + logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff); + } + logsumexp = Eigen::numext::log(logsumexp); + // Final normalization offset to get correct log probabilities. + float norm_offset = max_coeff + logsumexp; + const float label_selection_input_min = (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_) : -std::numeric_limits<float>::infinity(); @@ -288,10 +298,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( beam_scorer_->GetStateExpansionScore(b->state, previous)); } // Plabel(l=abc @ t=6) *= P(c @ 6) - b->newp.label += raw_input(b->label) - max_coeff; + b->newp.label += raw_input(b->label) - norm_offset; } // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6) - b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff; + b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset; // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6) b->newp.total = LogSumExp(b->newp.blank, b->newp.label); @@ -326,6 +336,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( const float logit = top_k ? top_k_logits[ind] : raw_input(ind); // Perform label selection: if input for this label looks very // unpromising, never evaluate it with a scorer. 
+ // We may compare logits instead of log probabilities, + // since the difference is the same in both cases. if (logit < label_selection_input_min) { continue; } @@ -339,7 +351,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step( // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6) beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label); float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total; - c.newp.label = logit - max_coeff + + c.newp.label = logit - norm_offset + beam_scorer_->GetStateExpansionScore(c.state, previous); // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6) c.newp.total = c.newp.label; diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc index 32458305c4..aa42b495bd 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc @@ -117,7 +117,7 @@ TEST(CTCBeamSearchTest, SimpleTest) { EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1)); // Check log probabilities output. EXPECT_THAT(m.GetLogProbabilitiesOutput(), - ElementsAreArray(ArrayFloatNear({0.32134813}))); + ElementsAreArray(ArrayFloatNear({-0.357094}))); } TEST(CTCBeamSearchTest, MultiBatchTest) { @@ -148,9 +148,8 @@ TEST(CTCBeamSearchTest, MultiBatchTest) { EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0)); EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2)); // Check log probabilities output. - EXPECT_THAT( - m.GetLogProbabilitiesOutput(), - ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572}))); + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({-1.88343, -1.41188, -1.20958}))); } TEST(CTCBeamSearchTest, MultiPathsTest) { @@ -188,8 +187,8 @@ TEST(CTCBeamSearchTest, MultiPathsTest) { EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2)); // Check log probabilities output. 
EXPECT_THAT(m.GetLogProbabilitiesOutput(), - ElementsAreArray(ArrayFloatNear( - {0.91318405, 0.9060272, 1.0780245, 0.64358956}))); + ElementsAreArray( + ArrayFloatNear({-2.65148, -2.65864, -2.17914, -2.61357}))); } TEST(CTCBeamSearchTest, NonEqualSequencesTest) { @@ -223,7 +222,7 @@ TEST(CTCBeamSearchTest, NonEqualSequencesTest) { EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1)); // Check log probabilities output. EXPECT_THAT(m.GetLogProbabilitiesOutput(), - ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005}))); + ElementsAreArray(ArrayFloatNear({-0.97322, -1.16334, -2.15553}))); } } // namespace diff --git a/tensorflow/contrib/lite/g3doc/README.md b/tensorflow/contrib/lite/g3doc/README.md deleted file mode 100644 index e3db478481..0000000000 --- a/tensorflow/contrib/lite/g3doc/README.md +++ /dev/null @@ -1,4 +0,0 @@ -This is a *work-in-progress* TF Lite subsite for: -https://www.tensorflow.org/mobile - -DO NOT PUBLISH diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md b/tensorflow/contrib/lite/g3doc/api_docs/python/index.md deleted file mode 100644 index 70031a3c3d..0000000000 --- a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md +++ /dev/null @@ -1,10 +0,0 @@ -Project: /mobile/_project.yaml -Book: /mobile/_book.yaml -page_type: reference -<style> table img { max-width: 100%; } </style> -<script src="/_static/js/managed/mathjax/MathJax.js?config=TeX-AMS-MML_SVG"></script> - -<!-- DO NOT EDIT! Automatically generated file. --> -# All symbols in TensorFlow Lite - -TEMP PAGE diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md index f255017ad9..69616c7b8a 100644 --- a/tensorflow/contrib/lite/g3doc/apis.md +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -37,7 +37,7 @@ float* output = interpreter->typed_output_tensor<float>(0); ``` ### Data Alignment -TensorFlow Lite data is usually aligned to 32-bit boundaries. It is recommended +TensorFlow Lite data is usually aligned to 16-byte boundaries. 
It is recommended that all data provided to TensorFlow Lite be aligned that way. ### Error Reporting @@ -112,7 +112,7 @@ below. It should be noted that: * Tensors are represented by integers, in order to avoid string comparisons (and any fixed dependency on string libraries). - * An interpreter must not be accessed from concurrent threads + * An interpreter must not be accessed from concurrent threads. * Memory allocation for input and output tensors must be triggered by calling AllocateTensors() right after resizing tensors. @@ -169,7 +169,7 @@ former provides error reporting facilities and access to global objects, including all the tensors. The latter allows implementations to access their inputs and outputs. -When the interpreter loads a model, it calls init() once for each node in the +When the interpreter loads a model, it calls `init()` once for each node in the graph. A given `init()` will be called more than once if the op is used multiple times in the graph. For custom ops a configuration buffer will be provided, containing a flexbuffer that maps parameter names to their values. @@ -210,8 +210,9 @@ namespace custom { Note that registration is not automatic and an explicit call to `Register_MY_CUSTOM_OP` should be made somewhere. While the standard -`:builtin_ops` takes care of the registration of builtins, custom ops will have -to be collected in separated custom libraries. +`BuiltinOpResolver` (available from the `:builtin_ops` target) takes care of the +registration of builtins, custom ops will have to be collected in separate +custom libraries. ### Customizing the kernel library @@ -232,7 +233,7 @@ class OpResolver { }; ``` -The regular usage will require the developer to use the `BuiltinOpResolver` and +Regular usage will require the developer to use the `BuiltinOpResolver` and write: ```c++ @@ -308,18 +309,25 @@ an `IllegalArgumentException` will be thrown. 
#### Inputs -Each input should be an array, a multi-dimensional array, or a `ByteBuffer` of -the supported primitive types. +Each input should be an array or multi-dimensional array of the supported +primitive types, or a raw `ByteBuffer` of the appropriate size. If the input is +an array or multi-dimensional array, the associated input tensor will be +implicitly resized to the array's dimensions at inference time. If the input is +a ByteBuffer, the caller should first manually resize the associated input +tensor (via `Interpreter.resizeInput()`) before running inference. -The use of `ByteBuffer` is preferred since it allows the `Interpreter` to avoid -unnecessary copies. Each `ByteBuffer` needs to be a direct byte buffer, and its -order must be `ByteOrder.nativeOrder()`. After it is used for a model inference, -it must remain unchanged until the model inference is finished. +When using 'ByteBuffer', prefer using direct byte buffers, as this allows the +`Interpreter` to avoid unnecessary copies. If the `ByteBuffer` is a direct byte +buffer, its order must be `ByteOrder.nativeOrder()`. After it is used for a +model inference, it must remain unchanged until the model inference is finished. #### Outputs -Each output should be an array, or a multi-dimensional array of the supported -primitive types. +Each output should be an array or multi-dimensional array of the supported +primitive types, or a ByteBuffer of the appropriate size. Note that some models +have dynamic outputs, where the shape of output tensors can vary depending on +the input. There's no straightforward way of handling this with the existing +Java inference API, but planned extensions will make this possible. #### Running Model Inference @@ -339,9 +347,10 @@ interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs); where each entry in `inputs` corresponds to an input tensor and `map_of_indices_to_outputs` maps indices of output tensors to the corresponding output data. 
In both cases the tensor indices should correspond to -the values given to the `TensorFlow Lite Optimized Converter` when the model was -created. Be aware that the order of tensors in `input` must match the order -given to the `TensorFlow Lite Optimized Converter`. +the values given to the [TensorFlow Lite Optimized Converter](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md) +when the model was created. Be aware that the order of tensors in `input` must +match the order given to the `TensorFlow Lite Optimized Converter`. + The Java API also provides convenient functions for app developers to get the index of any model input or output using a tensor name: diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 8287115f5c..b7c5cbf207 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -6,7 +6,7 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_opts_nortti_if_android") # Suppress warnings that are introduced by Eigen Tensor. 
EXTRA_EIGEN_COPTS = select({ @@ -147,7 +147,7 @@ tf_cc_test( ) cc_library( - name = "builtin_ops", + name = "builtin_op_kernels", srcs = [ "activations.cc", "add.cc", @@ -177,6 +177,7 @@ cc_library( "gather.cc", "hashtable_lookup.cc", "l2norm.cc", + "layer_norm_lstm.cc", "local_response_norm.cc", "logical.cc", "lsh_projection.cc", @@ -191,7 +192,7 @@ cc_library( "pooling.cc", "pow.cc", "reduce.cc", - "register.cc", + "relu1.cc", "reshape.cc", "resize_bilinear.cc", "select.cc", @@ -216,9 +217,9 @@ cc_library( ], hdrs = [ "padding.h", - "register.h", ], - copts = tflite_copts() + EXTRA_EIGEN_COPTS, + copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS, + visibility = ["//visibility:private"], deps = [ ":activation_functor", ":eigen_support", @@ -242,6 +243,17 @@ cc_library( ], ) +cc_library( + name = "builtin_ops", + srcs = ["register.cc"], + hdrs = ["register.h"], + deps = [ + ":builtin_op_kernels", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:util", + ], +) + tf_cc_test( name = "audio_spectrogram_test", size = "small", @@ -294,6 +306,23 @@ tf_cc_test( ) tf_cc_test( + name = "relu1_test", + size = "small", + srcs = ["relu1_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + +tf_cc_test( name = "activations_test", size = "small", srcs = ["activations_test.cc"], @@ -904,6 +933,20 @@ tf_cc_test( ) tf_cc_test( + name = "layer_norm_lstm_test", + size = "small", + srcs = ["layer_norm_lstm_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + +tf_cc_test( name = "lstm_test", size = "small", srcs = ["lstm_test.cc"], diff --git 
a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 9c891fe904..5cdd9fc94f 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -200,7 +200,7 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, input->type, output->type); const int num_dims = NumDimensions(input); - TF_LITE_ENSURE(context, num_dims == 1 || num_dims == 2 || num_dims == 4); + TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4); if (input->type == kTfLiteUInt8) { TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); @@ -453,6 +453,19 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, Softmax(input->data.f, input_size, batch_size, params->beta, output->data.f); } +// Takes a 3D tensor and perform softmax along the last dimension. +void Softmax3DFloat(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int batch_size = input->dims->data[0]; + const int intermediate_size = input->dims->data[1]; + const int input_size = input->dims->data[2]; + optimized_ops::Softmax( + GetTensorData<float>(input), + GetTensorShape({batch_size, intermediate_size, 1, input_size}), + params->beta, GetTensorData<float>(output), + GetTensorShape({batch_size, intermediate_size, 1, input_size})); +} + void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params, OpData* data) { // TODO(ahentz): this is arguably a dirty trick. 
Since the implementation @@ -480,6 +493,19 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, GetTensorShape({batch_size, 1, 1, input_size})); } +void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + const int batch_size = input->dims->data[0]; + const int intermediate_size = input->dims->data[1]; + const int input_size = input->dims->data[2]; + optimized_ops::Softmax( + GetTensorData<uint8_t>(input), + GetTensorShape({batch_size, intermediate_size, 1, input_size}), + data->input_multiplier, data->input_left_shift, data->diff_min, + GetTensorData<uint8_t>(output), + GetTensorShape({batch_size, intermediate_size, 1, input_size})); +} + // Takes a 4D tensor and perform softmax along the forth dimension. void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params) { @@ -515,6 +541,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { Softmax2DFloat(input, output, params); return kTfLiteOk; } + if (NumDimensions(input) == 3) { + Softmax3DFloat(input, output, params); + return kTfLiteOk; + } if (NumDimensions(input) == 4) { Softmax4DFloat(input, output, params); return kTfLiteOk; @@ -533,6 +563,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { Softmax2DQuantized(input, output, params, data); return kTfLiteOk; } + if (NumDimensions(input) == 3) { + Softmax3DQuantized(input, output, params, data); + return kTfLiteOk; + } if (NumDimensions(input) == 4) { Softmax4DQuantized(input, output, params, data); return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index e577e3a762..9fa47e190a 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -339,6 +339,76 @@ TEST(QuantizedActivationsOpTest, Softmax4D) { kQuantizedTolerance))); } 
+TEST(FloatActivationsOpTest, Softmax3D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {1, 2, 4}}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }))); + + // Same input, but a different shape. + FloatActivationsOpModel m2(0.1, + /*input=*/{TensorType_FLOAT32, {4, 1, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }))); +} + +TEST(QuantizedActivationsOpTest, Softmax3D) { + QuantizedActivationsOpModel m( + 0.1, + /*input=*/{TensorType_UINT8, {1, 2, 4}, -10, 10}); + m.SetInput<uint8_t>({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. 
+ QuantizedActivationsOpModel m2( + 0.1, + /*input=*/{TensorType_UINT8, {4, 1, 2}, -10, 10}); + m2.SetInput<uint8_t>({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + TEST(FloatActivationsOpTest, Softmax1D) { FloatActivationsOpModel m(0.1, /*input=*/{TensorType_FLOAT32, {8}}); diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index af47b33922..cde4f55a16 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -108,9 +108,26 @@ constexpr int kBwInputCellStateTensor = 38; constexpr int kFwOutputTensor = 0; constexpr int kBwOutputTensor = 1; +// Temporary tensors. +enum TemporaryTensor { + // Scratch buffers for input, forget, etc. gates + kFwScratchBuffer = 0, + kBwScratchBuffer = 1, + // Quantized tensors needed for the hybrid kernel. 
+ kInputQuantized = 2, + kFwActivationStateQuantized = 3, + kBwActivationStateQuantized = 4, + kFwCellStateQuantized = 5, + kBwCellStateQuantized = 6, + kScalingFactors = 7, + kProductScalingFactors = 8, + kRecoveredCellWeights = 9, + kNumTemporaryTensors = 10 +}; + void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index); + context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index); return scratch_tensor_index; } @@ -131,7 +148,7 @@ TfLiteStatus CheckLstmTensorDimensions( int input_gate_bias_tensor, int forget_gate_bias_tensor, int cell_gate_bias_tensor, int output_gate_bias_tensor, int projection_weights_tensor, int projection_bias_tensor) { - auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); + const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); // Making sure clipping parameters have valid values. // == 0 means no clipping @@ -324,7 +341,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Inferring batch size, number of outputs and sequence length and // number of cells from the input tensors. const TfLiteTensor* input = GetInput(context, node, kInputTensor); - TF_LITE_ENSURE(context, input->dims->size > 1); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->dims->size, 3); const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; const int n_input = input->dims->data[2]; @@ -370,11 +388,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output, fw_output_size)); - // Create a scratch buffer tensor. + // The weights are of consistent type, so it suffices to check one. 
+ const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8); + TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(2); - node->temporaries->data[0] = *scratch_tensor_index; - TfLiteTensor* fw_scratch_buffer = GetTemporary(context, node, /*index=*/0); + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); + } else { + node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers. + } + // Create a scratch buffer tensor. + node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index; + TfLiteTensor* fw_scratch_buffer = + GetTemporary(context, node, kFwScratchBuffer); fw_scratch_buffer->type = input->type; fw_scratch_buffer->allocation_type = kTfLiteArenaRw; @@ -435,8 +461,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell); // Create a scratch buffer tensor. - node->temporaries->data[1] = *(scratch_tensor_index) + 1; - TfLiteTensor* bw_scratch_buffer = GetTemporary(context, node, /*index=*/1); + node->temporaries->data[kBwScratchBuffer] = + *(scratch_tensor_index) + kBwScratchBuffer; + TfLiteTensor* bw_scratch_buffer = + GetTemporary(context, node, kBwScratchBuffer); bw_scratch_buffer->type = input->type; bw_scratch_buffer->allocation_type = kTfLiteArenaRw; @@ -454,18 +482,441 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer, bw_scratch_buffer_size)); + if (is_hybrid_op) { + // Allocate temporary tensors to store quantized values of input, + // output_state and cell_state tensors. 
+ node->temporaries->data[kInputQuantized] = + *scratch_tensor_index + kInputQuantized; + TfLiteTensor* input_quantized = + GetTemporary(context, node, kInputQuantized); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + + node->temporaries->data[kFwActivationStateQuantized] = + *scratch_tensor_index + kFwActivationStateQuantized; + TfLiteTensor* fw_activation_state_quantized = + GetTemporary(context, node, kFwActivationStateQuantized); + fw_activation_state_quantized->type = kTfLiteUInt8; + fw_activation_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims, + fw_activation_state->dims)) { + TfLiteIntArray* fw_activation_state_quantized_size = + TfLiteIntArrayCopy(fw_activation_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, fw_activation_state_quantized, + fw_activation_state_quantized_size)); + } + node->temporaries->data[kBwActivationStateQuantized] = + *scratch_tensor_index + kBwActivationStateQuantized; + TfLiteTensor* bw_activation_state_quantized = + GetTemporary(context, node, kBwActivationStateQuantized); + bw_activation_state_quantized->type = kTfLiteUInt8; + bw_activation_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims, + bw_activation_state->dims)) { + TfLiteIntArray* bw_activation_state_quantized_size = + TfLiteIntArrayCopy(bw_activation_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, bw_activation_state_quantized, + bw_activation_state_quantized_size)); + } + node->temporaries->data[kFwCellStateQuantized] = + *scratch_tensor_index + kFwCellStateQuantized; + TfLiteTensor* 
fw_cell_state_quantized = + GetTemporary(context, node, kFwCellStateQuantized); + fw_cell_state_quantized->type = kTfLiteUInt8; + fw_cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims, + fw_cell_state->dims)) { + TfLiteIntArray* fw_cell_state_quantized_size = + TfLiteIntArrayCopy(fw_cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, fw_cell_state_quantized, + fw_cell_state_quantized_size)); + } + node->temporaries->data[kBwCellStateQuantized] = + *scratch_tensor_index + kBwCellStateQuantized; + TfLiteTensor* bw_cell_state_quantized = + GetTemporary(context, node, kBwCellStateQuantized); + bw_cell_state_quantized->type = kTfLiteUInt8; + bw_cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims, + bw_cell_state->dims)) { + TfLiteIntArray* bw_cell_state_quantized_size = + TfLiteIntArrayCopy(bw_cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, bw_cell_state_quantized, + bw_cell_state_quantized_size)); + } + + // Allocate temporary tensors to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). 
+ node->temporaries->data[kScalingFactors] = + *scratch_tensor_index + kScalingFactors; + TfLiteTensor* scaling_factors = + GetTemporary(context, node, kScalingFactors); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + node->temporaries->data[kProductScalingFactors] = + *scratch_tensor_index + kProductScalingFactors; + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, kProductScalingFactors); + prod_scaling_factors->type = kTfLiteFloat32; + prod_scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); + prod_scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(prod_scaling_factors->dims, + prod_scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, prod_scaling_factors, + prod_scaling_factors_size)); + } + + // Allocate a temporary tensor to store the recovered cell weights. Since + // this is used for diagonal matrices, only need to store n_cell values. 
+ node->temporaries->data[kRecoveredCellWeights] = + *scratch_tensor_index + kRecoveredCellWeights; + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, kRecoveredCellWeights); + recovered_cell_weights->type = kTfLiteFloat32; + recovered_cell_weights->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); + recovered_cell_weights_size->data[0] = n_fw_cell; + if (!TfLiteIntArrayEqual(recovered_cell_weights->dims, + recovered_cell_weights_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, recovered_cell_weights, + recovered_cell_weights_size)); + } + } + return kTfLiteOk; +} + +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, bool forward_sequence, + TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state, + TfLiteTensor* cell_state, TfLiteTensor* output) { + const int max_time = input->dims->data[0]; + const int n_batch = input->dims->data[1]; + const int n_input = input->dims->data[2]; + + // n_cell and n_output will be the same size when there is no projection. 
+ const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + // Index the scratch buffers pointers to the global scratch buffer. + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + const float* input_to_input_weights_ptr = + (use_cifg) ? nullptr : input_to_input_weights->data.f; + const float* recurrent_to_input_weights_ptr = + (use_cifg) ? nullptr : recurrent_to_input_weights->data.f; + const float* input_gate_bias_ptr = + (use_cifg) ? nullptr : input_gate_bias->data.f; + const float* cell_to_input_weights_ptr = + (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr; + const float* cell_to_forget_weights_ptr = + (use_peephole) ? cell_to_forget_weights->data.f : nullptr; + const float* cell_to_output_weights_ptr = + (use_peephole) ? cell_to_output_weights->data.f : nullptr; + const float* projection_weights_ptr = + (projection_weights == nullptr) ? nullptr : projection_weights->data.f; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? 
nullptr : projection_bias->data.f; + + // Loop through the sequence. + if (forward_sequence) { + for (int t = 0; t < max_time; t++) { + const float* input_ptr = input->data.f + t * n_batch * n_input; + float* output_ptr_time = output->data.f + t * n_batch * n_output; + + kernel_utils::LstmStep( + input_ptr, input_to_input_weights_ptr, + input_to_forget_weights->data.f, input_to_cell_weights->data.f, + input_to_output_weights->data.f, recurrent_to_input_weights_ptr, + recurrent_to_forget_weights->data.f, + recurrent_to_cell_weights->data.f, + recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, + cell_to_forget_weights_ptr, cell_to_output_weights_ptr, + input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, + output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, + params, n_batch, n_cell, n_input, n_output, activation_state->data.f, + cell_state->data.f, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, output_ptr_time); + } + } else { + // Loop through the sequence backwards. 
+ for (int t = max_time - 1; t >= 0; t--) { + const float* input_ptr = input->data.f + t * n_batch * n_input; + float* output_ptr_time = output->data.f + t * n_batch * n_output; + + kernel_utils::LstmStep( + input_ptr, input_to_input_weights_ptr, + input_to_forget_weights->data.f, input_to_cell_weights->data.f, + input_to_output_weights->data.f, recurrent_to_input_weights_ptr, + recurrent_to_forget_weights->data.f, + recurrent_to_cell_weights->data.f, + recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, + cell_to_forget_weights_ptr, cell_to_output_weights_ptr, + input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, + output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, + params, n_batch, n_cell, n_input, n_output, activation_state->data.f, + cell_state->data.f, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, output_ptr_time); + } + } + return kTfLiteOk; +} + +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, bool forward_sequence, + TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors, + TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, + TfLiteTensor* input_quantized, TfLiteTensor* 
output_state_quantized, + TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, + TfLiteTensor* cell_state, TfLiteTensor* output) { + const int max_time = input->dims->data[0]; + const int n_batch = input->dims->data[1]; + const int n_input = input->dims->data[2]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. 
+ int8_t* input_to_input_weights_ptr = nullptr; + float input_to_input_weights_scale = 1.0f; + int8_t* recurrent_to_input_weights_ptr = nullptr; + float recurrent_to_input_weights_scale = 1.0f; + float* input_gate_bias_ptr = nullptr; + if (!use_cifg) { + input_to_input_weights_ptr = + reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8); + recurrent_to_input_weights_ptr = + reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8); + input_gate_bias_ptr = input_gate_bias->data.f; + input_to_input_weights_scale = input_to_input_weights->params.scale; + recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; + } + + int8_t* cell_to_input_weights_ptr = nullptr; + int8_t* cell_to_forget_weights_ptr = nullptr; + int8_t* cell_to_output_weights_ptr = nullptr; + float cell_to_input_weights_scale = 1.0f; + float cell_to_forget_weights_scale = 1.0f; + float cell_to_output_weights_scale = 1.0f; + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weights_ptr = + reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8); + cell_to_input_weights_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weights_ptr = + reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8); + cell_to_output_weights_ptr = + reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8); + cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; + cell_to_output_weights_scale = cell_to_output_weights->params.scale; + } + + const int8_t* projection_weights_ptr = + (projection_weights == nullptr) + ? nullptr + : reinterpret_cast<int8_t*>(projection_weights->data.uint8); + const float projection_weights_scale = + (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. 
+ const int8_t* input_to_forget_weights_ptr = + reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8); + const float input_to_forget_weights_scale = + input_to_forget_weights->params.scale; + const int8_t* input_to_cell_weights_ptr = + reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8); + const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; + const int8_t* input_to_output_weights_ptr = + reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8); + const float input_to_output_weights_scale = + input_to_output_weights->params.scale; + const int8_t* recurrent_to_forget_weights_ptr = + reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8); + const float recurrent_to_forget_weights_scale = + recurrent_to_forget_weights->params.scale; + const int8_t* recurrent_to_cell_weights_ptr = + reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8); + const float recurrent_to_cell_weights_scale = + recurrent_to_cell_weights->params.scale; + const int8_t* recurrent_to_output_weights_ptr = + reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8); + const float recurrent_to_output_weights_scale = + recurrent_to_output_weights->params.scale; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* output_state_ptr = output_state->data.f; + float* cell_state_ptr = cell_state->data.f; + + // Temporary storage for quantized values and scaling factors. 
+ int8_t* quantized_input_ptr = + reinterpret_cast<int8_t*>(input_quantized->data.uint8); + int8_t* quantized_output_state_ptr = + reinterpret_cast<int8_t*>(output_state_quantized->data.uint8); + int8_t* quantized_cell_state_ptr = + reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; + float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; + + if (forward_sequence) { + // Feed the sequence into the LSTM step-by-step. + for (int t = 0; t < max_time; t++) { + const float* input_ptr = input->data.f + t * n_batch * n_input; + float* output_ptr = output->data.f + t * n_batch * n_output; + + kernel_utils::LstmStep( + input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, + cell_to_input_weights_ptr, cell_to_input_weights_scale, + cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + cell_to_output_weights_ptr, cell_to_output_weights_scale, + input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, + output_gate_bias_ptr, projection_weights_ptr, + projection_weights_scale, projection_bias_ptr, params, n_batch, + n_cell, n_input, n_output, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, scaling_factors_ptr, + prod_scaling_factors_ptr, recovered_cell_weights_ptr, + quantized_input_ptr, quantized_output_state_ptr, + quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, + output_ptr); + } + } else { + // Loop through the sequence 
backwards. + for (int t = max_time - 1; t >= 0; t--) { + const float* input_ptr = input->data.f + t * n_batch * n_input; + float* output_ptr = output->data.f + t * n_batch * n_output; + + kernel_utils::LstmStep( + input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, + cell_to_input_weights_ptr, cell_to_input_weights_scale, + cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + cell_to_output_weights_ptr, cell_to_output_weights_scale, + input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, + output_gate_bias_ptr, projection_weights_ptr, + projection_weights_scale, projection_bias_ptr, params, n_batch, + n_cell, n_input, n_output, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, scaling_factors_ptr, + prod_scaling_factors_ptr, recovered_cell_weights_ptr, + quantized_input_ptr, quantized_output_state_ptr, + quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, + output_ptr); + } + } + return kTfLiteOk; } // The LSTM Op engine. TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); + const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); // Input tensor. const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const int max_time = input->dims->data[0]; - const int n_batch = input->dims->data[1]; - const int n_input = input->dims->data[2]; // Tensors for the forward cell. 
const TfLiteTensor* fw_input_to_input_weights = @@ -559,149 +1010,91 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetVariableInput(context, node, kBwInputCellStateTensor); TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); - // n_cell and n_output will be the same size when there is no projection. - const int n_fw_cell = fw_input_to_output_weights->dims->data[0]; - const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. - const bool fw_use_cifg = (fw_input_to_input_weights == nullptr); - const bool fw_use_peephole = (fw_cell_to_output_weights != nullptr); - - // Index the scratch buffers pointers to the global scratch buffer. TfLiteTensor* fw_scratch_buffer = - &context->tensors[node->temporaries->data[0]]; - float* fw_input_gate_scratch = nullptr; - float* fw_cell_scratch = nullptr; - float* fw_forget_gate_scratch = nullptr; - float* fw_output_gate_scratch = nullptr; - if (fw_use_cifg) { - fw_cell_scratch = fw_scratch_buffer->data.f; - fw_forget_gate_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch; - fw_output_gate_scratch = - fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch; - } else { - fw_input_gate_scratch = fw_scratch_buffer->data.f; - fw_cell_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch; - fw_forget_gate_scratch = - fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch; - fw_output_gate_scratch = - fw_scratch_buffer->data.f + 3 * n_fw_cell * n_batch; - } - - // Check optional tensors, the respective pointers can be null. - const float* fw_input_to_input_weights_ptr = - (fw_use_cifg) ? nullptr : fw_input_to_input_weights->data.f; - const float* fw_recurrent_to_input_weights_ptr = - (fw_use_cifg) ? nullptr : fw_recurrent_to_input_weights->data.f; - const float* fw_input_gate_bias_ptr = - (fw_use_cifg) ? 
nullptr : fw_input_gate_bias->data.f; - const float* fw_cell_to_input_weights_ptr = - (fw_use_peephole && !fw_use_cifg) ? fw_cell_to_input_weights->data.f - : nullptr; - const float* fw_cell_to_forget_weights_ptr = - (fw_use_peephole) ? fw_cell_to_forget_weights->data.f : nullptr; - const float* fw_cell_to_output_weights_ptr = - (fw_use_peephole) ? fw_cell_to_output_weights->data.f : nullptr; - const float* fw_projection_weights_ptr = (fw_projection_weights == nullptr) - ? nullptr - : fw_projection_weights->data.f; - const float* fw_projection_bias_ptr = - (fw_projection_bias == nullptr) ? nullptr : fw_projection_bias->data.f; - - // Loop through the sequence. - for (int t = 0; t < max_time; t++) { - const float* input_ptr_batch = input->data.f + t * n_batch * n_input; - float* output_ptr_time = fw_output->data.f + t * n_batch * n_fw_output; - - kernel_utils::LstmStep( - input_ptr_batch, fw_input_to_input_weights_ptr, - fw_input_to_forget_weights->data.f, fw_input_to_cell_weights->data.f, - fw_input_to_output_weights->data.f, fw_recurrent_to_input_weights_ptr, - fw_recurrent_to_forget_weights->data.f, - fw_recurrent_to_cell_weights->data.f, - fw_recurrent_to_output_weights->data.f, fw_cell_to_input_weights_ptr, - fw_cell_to_forget_weights_ptr, fw_cell_to_output_weights_ptr, - fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f, - fw_cell_bias->data.f, fw_output_gate_bias->data.f, - fw_projection_weights_ptr, fw_projection_bias_ptr, params, n_batch, - n_fw_cell, n_input, n_fw_output, fw_activation_state->data.f, - fw_cell_state->data.f, fw_input_gate_scratch, fw_forget_gate_scratch, - fw_cell_scratch, fw_output_gate_scratch, output_ptr_time); - } - - // n_cell and n_output will be the same size when there is no projection. 
- const int n_bw_cell = bw_input_to_output_weights->dims->data[0]; - const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. - const bool bw_use_cifg = (bw_input_to_input_weights == nullptr); - const bool bw_use_peephole = (bw_cell_to_output_weights != nullptr); - - // Index the scratch buffers pointers to the global scratch buffer. + GetTemporary(context, node, kFwScratchBuffer); TfLiteTensor* bw_scratch_buffer = - &context->tensors[node->temporaries->data[1]]; - float* bw_input_gate_scratch = nullptr; - float* bw_cell_scratch = nullptr; - float* bw_forget_gate_scratch = nullptr; - float* bw_output_gate_scratch = nullptr; - if (bw_use_cifg) { - bw_cell_scratch = bw_scratch_buffer->data.f; - bw_forget_gate_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch; - bw_output_gate_scratch = - bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch; - } else { - bw_input_gate_scratch = bw_scratch_buffer->data.f; - bw_cell_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch; - bw_forget_gate_scratch = - bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch; - bw_output_gate_scratch = - bw_scratch_buffer->data.f + 3 * n_bw_cell * n_batch; + GetTemporary(context, node, kBwScratchBuffer); + + switch (fw_input_to_output_weights->type) { + case kTfLiteFloat32: { + TfLiteStatus fw_pass_status = EvalFloat( + input, fw_input_to_input_weights, fw_input_to_forget_weights, + fw_input_to_cell_weights, fw_input_to_output_weights, + fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights, + fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights, + fw_cell_to_input_weights, fw_cell_to_forget_weights, + fw_cell_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias, + fw_cell_bias, fw_output_gate_bias, fw_projection_weights, + fw_projection_bias, params, /*forward_sequence=*/true, + fw_scratch_buffer, 
fw_activation_state, fw_cell_state, fw_output); + TF_LITE_ENSURE_OK(context, fw_pass_status); + + TfLiteStatus bw_pass_status = EvalFloat( + input, bw_input_to_input_weights, bw_input_to_forget_weights, + bw_input_to_cell_weights, bw_input_to_output_weights, + bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, + bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights, + bw_cell_to_input_weights, bw_cell_to_forget_weights, + bw_cell_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias, + bw_cell_bias, bw_output_gate_bias, bw_projection_weights, + bw_projection_bias, params, /*forward_sequence=*/false, + bw_scratch_buffer, bw_activation_state, bw_cell_state, bw_output); + TF_LITE_ENSURE_OK(context, bw_pass_status); + return kTfLiteOk; + } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = + GetTemporary(context, node, kInputQuantized); + TfLiteTensor* fw_activation_state_quantized = + GetTemporary(context, node, kFwActivationStateQuantized); + TfLiteTensor* bw_activation_state_quantized = + GetTemporary(context, node, kBwActivationStateQuantized); + TfLiteTensor* fw_cell_state_quantized = + GetTemporary(context, node, kFwCellStateQuantized); + TfLiteTensor* bw_cell_state_quantized = + GetTemporary(context, node, kBwCellStateQuantized); + TfLiteTensor* scaling_factors = + GetTemporary(context, node, kScalingFactors); + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, kProductScalingFactors); + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, kRecoveredCellWeights); + TfLiteStatus fw_pass_status = EvalHybrid( + input, fw_input_to_input_weights, fw_input_to_forget_weights, + fw_input_to_cell_weights, fw_input_to_output_weights, + fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights, + fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights, + fw_cell_to_input_weights, fw_cell_to_forget_weights, + fw_cell_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias, + fw_cell_bias, 
fw_output_gate_bias, fw_projection_weights, + fw_projection_bias, params, /*forward_sequence=*/true, + fw_scratch_buffer, scaling_factors, prod_scaling_factors, + recovered_cell_weights, input_quantized, + fw_activation_state_quantized, fw_cell_state_quantized, + fw_activation_state, fw_cell_state, fw_output); + TF_LITE_ENSURE_OK(context, fw_pass_status); + + TfLiteStatus bw_pass_status = EvalHybrid( + input, bw_input_to_input_weights, bw_input_to_forget_weights, + bw_input_to_cell_weights, bw_input_to_output_weights, + bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, + bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights, + bw_cell_to_input_weights, bw_cell_to_forget_weights, + bw_cell_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias, + bw_cell_bias, bw_output_gate_bias, bw_projection_weights, + bw_projection_bias, params, /*forward_sequence=*/false, + bw_scratch_buffer, scaling_factors, prod_scaling_factors, + recovered_cell_weights, input_quantized, + bw_activation_state_quantized, bw_cell_state_quantized, + bw_activation_state, bw_cell_state, bw_output); + TF_LITE_ENSURE_OK(context, bw_pass_status); + return kTfLiteOk; + } + default: + context->ReportError(context, "Type %d is not currently supported.", + fw_input_to_output_weights->type); + return kTfLiteError; } - - // Check optional tensors, the respective pointers can be null. - const float* bw_input_to_input_weights_ptr = - (bw_use_cifg) ? nullptr : bw_input_to_input_weights->data.f; - const float* bw_recurrent_to_input_weights_ptr = - (bw_use_cifg) ? nullptr : bw_recurrent_to_input_weights->data.f; - const float* bw_input_gate_bias_ptr = - (bw_use_cifg) ? nullptr : bw_input_gate_bias->data.f; - const float* bw_cell_to_input_weights_ptr = - (bw_use_peephole && !bw_use_cifg) ? bw_cell_to_input_weights->data.f - : nullptr; - const float* bw_cell_to_forget_weights_ptr = - (bw_use_peephole) ? 
bw_cell_to_forget_weights->data.f : nullptr; - const float* bw_cell_to_output_weights_ptr = - (bw_use_peephole) ? bw_cell_to_output_weights->data.f : nullptr; - const float* bw_projection_weights_ptr = (bw_projection_weights == nullptr) - ? nullptr - : bw_projection_weights->data.f; - const float* bw_projection_bias_ptr = - (bw_projection_bias == nullptr) ? nullptr : bw_projection_bias->data.f; - - // Loop through the sequence backwards. - for (int t = max_time - 1; t >= 0; t--) { - const float* input_ptr_batch = input->data.f + t * n_batch * n_input; - float* output_ptr_time = bw_output->data.f + t * n_batch * n_bw_output; - - kernel_utils::LstmStep( - input_ptr_batch, bw_input_to_input_weights_ptr, - bw_input_to_forget_weights->data.f, bw_input_to_cell_weights->data.f, - bw_input_to_output_weights->data.f, bw_recurrent_to_input_weights_ptr, - bw_recurrent_to_forget_weights->data.f, - bw_recurrent_to_cell_weights->data.f, - bw_recurrent_to_output_weights->data.f, bw_cell_to_input_weights_ptr, - bw_cell_to_forget_weights_ptr, bw_cell_to_output_weights_ptr, - bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f, - bw_cell_bias->data.f, bw_output_gate_bias->data.f, - bw_projection_weights_ptr, bw_projection_bias_ptr, params, n_batch, - n_bw_cell, n_input, n_bw_output, bw_activation_state->data.f, - bw_cell_state->data.f, bw_input_gate_scratch, bw_forget_gate_scratch, - bw_cell_scratch, bw_output_gate_scratch, output_ptr_time); - } - - // Backward step. 
return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index e671624fe7..5ca1b4b76f 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -79,6 +79,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1, n_batch, result, result_stride); } +void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch, + float* batch_vector) { + PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector); +} + void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, float* batch_vector) { PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector); @@ -138,6 +143,13 @@ void ReductionSumVector(const float* input_vector, float* output_vector, reduction_size); } +void MeanStddevNormalization(const float* input_vector, float* output_vector, + int v_size, int n_batch, + float normalization_epsilon) { + PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch, + normalization_epsilon); +} + } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 70adffda3b..2c8e8f90e3 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -43,6 +43,14 @@ namespace optimized_ops { // Unoptimized reference ops: using reference_ops::ArgMax; using reference_ops::ArgMinMax; +using reference_ops::Broadcast4DSlowGreater; +using reference_ops::Broadcast4DSlowGreaterEqual; +using reference_ops::Broadcast4DSlowGreaterEqualWithScaling; +using reference_ops::Broadcast4DSlowGreaterWithScaling; +using reference_ops::Broadcast4DSlowLess; +using 
reference_ops::Broadcast4DSlowLessEqual; +using reference_ops::Broadcast4DSlowLessEqualWithScaling; +using reference_ops::Broadcast4DSlowLessWithScaling; using reference_ops::BroadcastAdd4DSlow; using reference_ops::BroadcastGreater; using reference_ops::BroadcastGreaterEqual; @@ -58,8 +66,12 @@ using reference_ops::FakeQuant; using reference_ops::Gather; using reference_ops::Greater; using reference_ops::GreaterEqual; +using reference_ops::GreaterEqualWithScaling; +using reference_ops::GreaterWithScaling; using reference_ops::Less; using reference_ops::LessEqual; +using reference_ops::LessEqualWithScaling; +using reference_ops::LessWithScaling; using reference_ops::Mean; using reference_ops::RankOneSelect; using reference_ops::Relu1; @@ -67,6 +79,7 @@ using reference_ops::Relu6; using reference_ops::ReluX; using reference_ops::Select; using reference_ops::SpaceToBatchND; +using reference_ops::Split; using reference_ops::StridedSlice; using reference_ops::Transpose; diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index 8664ebc4f6..7e53dc2fa2 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -117,6 +117,10 @@ void PortableClipVector(const float* vector, int v_size, float abs_limit, void NeonClipVector(const float* vector, int v_size, float abs_limit, float* result); +// Add another vector for each batch in the batch vector. +void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch, + float* batch_vector); + // Batch vector initialization with another vector. 
void PortableVectorBatchVectorAssign(const float* vector, int v_size, int n_batch, float* batch_vector); @@ -172,6 +176,10 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector, void NeonReductionSumVector(const float* input_vector, float* output_vector, int output_size, int reduction_size); +void PortableMeanStddevNormalization(const float* input_vector, + float* output_vector, int v_size, + int n_batch, float normalization_epsilon); + } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index f882f9910e..544ef16ce1 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -23,6 +23,32 @@ limitations under the License. namespace tflite { +namespace { +// These constants are used to manipulate the binary representation of doubles. +// Double-precision binary64 floating point format is: +// Bit | 63 | 62-52 | 51-0 | +// | Sign | Exponent | Fraction | +// To avoid 64-bit integers as much as possible, I break this into high and +// low 32-bit chunks. High is: +// Bit | 31 | 30-20 | 19-0 | +// | Sign | Exponent | High Fraction | +// Low is: +// Bit | 31-0 | +// | Low Fraction | +// We then access the components through logical bit-wise operations to +// extract the parts needed, with the positions and masks derived from the +// layout shown above. 
+constexpr uint64_t kSignMask = 0x8000000000000000LL; +constexpr uint64_t kExponentMask = 0x7ff0000000000000LL; +constexpr int32_t kExponentShift = 52; +constexpr int32_t kExponentBias = 1023; +constexpr uint32_t kExponentIsBadNum = 0x7ff; +constexpr uint64_t kFractionMask = 0x000fffffffc00000LL; +constexpr uint32_t kFractionShift = 22; +constexpr uint32_t kFractionRoundingMask = 0x003fffff; +constexpr uint32_t kFractionRoundingThreshold = 0x00200000; +} // namespace + void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, int* shift) { if (double_multiplier == 0.) { @@ -30,8 +56,16 @@ void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, *shift = 0; return; } +#ifdef TFLITE_EMULATE_FLOAT + // If we're trying to avoid the use of floating-point instructions (for + // example on microcontrollers) then use an alternative implementation + // that only requires integer and bitwise operations. To enable this, you + // need to set the define during the build process for your platform. + int64_t q_fixed = IntegerFrExp(double_multiplier, shift); +#else // TFLITE_EMULATE_FLOAT const double q = std::frexp(double_multiplier, shift); auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31))); +#endif // TFLITE_EMULATE_FLOAT TFLITE_CHECK(q_fixed <= (1ll << 31)); if (q_fixed == (1ll << 31)) { q_fixed /= 2; @@ -60,6 +94,163 @@ void QuantizeMultiplierSmallerThanOneExp(double double_multiplier, *left_shift = shift; } +int64_t IntegerFrExp(double input, int* shift) { + // Make sure our assumptions about the double layout hold. + TFLITE_CHECK_EQ(8, sizeof(double)); + + // We want to access the bits of the input double value directly, which is + // tricky to do safely, so use a union to handle the casting. 
+ union { + double double_value; + uint64_t double_as_uint; + } cast_union; + cast_union.double_value = input; + const uint64_t u = cast_union.double_as_uint; + + // If the bitfield is all zeros apart from the sign bit, this is a normalized + // zero value, so return standard values for this special case. + if ((u & ~kSignMask) == 0) { + *shift = 0; + return 0; + } + + // Deal with NaNs and Infs, which are always indicated with a fixed pattern in + // the exponent, and distinguished by whether the fractions are zero or + // non-zero. + const uint32_t exponent_part = ((u & kExponentMask) >> kExponentShift); + if (exponent_part == kExponentIsBadNum) { + *shift = std::numeric_limits<int>::max(); + if (u & kFractionMask) { + // NaN, so just return zero (with the exponent set to INT_MAX). + return 0; + } else { + // Infinity, so return +/- INT_MAX. + if (u & kSignMask) { + return std::numeric_limits<int64_t>::min(); + } else { + return std::numeric_limits<int64_t>::max(); + } + } + } + + // The shift is fairly easy to extract from the high bits of the double value, + // just by masking it out and applying a bias. The std::frexp() implementation + // always returns values between 0.5 and 1.0 though, whereas the exponent + // assumes 1.0 to 2.0 is the standard range, so I add on one to match that + // interface. + *shift = (exponent_part - kExponentBias) + 1; + + // There's an implicit high bit in the double format definition, so make sure + // we include that at the top, and then reconstruct the rest of the fractional + // value from the remaining fragments. + int64_t fraction = 0x40000000 + ((u & kFractionMask) >> kFractionShift); + + // We're cutting off some bits at the bottom, so to exactly match the standard + // frexp implementation here we'll apply rounding by adding one to the least + // significant bit of the result if the discarded portion is over half of the + // maximum. 
+ if ((u & kFractionRoundingMask) > kFractionRoundingThreshold) { + fraction += 1; + } + // Negate the fraction if the sign bit was set. + if (u & kSignMask) { + fraction *= -1; + } + + return fraction; +} + +double DoubleFromFractionAndShift(int64_t fraction, int shift) { + union { + double double_value; + uint64_t double_as_uint; + } result; + + // Detect NaNs and infinities. + if (shift == std::numeric_limits<int>::max()) { + if (fraction == 0) { + return NAN; + } else if (fraction > 0) { + return INFINITY; + } else { + return -INFINITY; + } + } + + // Return a normalized zero for a zero fraction. + if (fraction == 0) { + result.double_as_uint = 0; + return result.double_value; + } + + bool is_negative = (fraction < 0); + int64_t encoded_fraction = is_negative ? -fraction : fraction; + int64_t encoded_shift = (shift - 1); + while (encoded_fraction < 0x40000000) { + encoded_fraction *= 2; + encoded_shift -= 1; + } + while (encoded_fraction > 0x80000000) { + encoded_fraction /= 2; + encoded_shift += 1; + } + encoded_fraction -= 0x40000000; + if (encoded_shift < -1022) { + encoded_shift = -1023; + } else if (encoded_shift > 1022) { + encoded_shift = 1023; + } + encoded_shift += kExponentBias; + uint64_t encoded_sign = is_negative ? kSignMask : 0; + result.double_as_uint = encoded_sign | (encoded_shift << kExponentShift) | + (encoded_fraction << kFractionShift); + return result.double_value; +} + +double IntegerDoubleMultiply(double a, double b) { + int a_shift; + const int64_t a_fraction = IntegerFrExp(a, &a_shift); + int b_shift; + const int64_t b_fraction = IntegerFrExp(b, &b_shift); + // Detect NaNs and infinities. 
+ if (a_shift == std::numeric_limits<int>::max() || + (b_shift == std::numeric_limits<int>::max())) { + return NAN; + } + const int result_shift = a_shift + b_shift + 1; + const int64_t result_fraction = (a_fraction * b_fraction) >> 32; + return DoubleFromFractionAndShift(result_fraction, result_shift); +} + +int IntegerDoubleCompare(double a, double b) { + int a_shift; + const int64_t a_fraction = IntegerFrExp(a, &a_shift); + int b_shift; + const int64_t b_fraction = IntegerFrExp(b, &b_shift); + + // Detect NaNs and infinities. + if (a_shift == std::numeric_limits<int>::max() || + (b_shift == std::numeric_limits<int>::max())) { + return 1; + } + + if ((a_fraction == 0) && (b_fraction < 0)) { + return 1; + } else if ((a_fraction < 0) && (b_fraction == 0)) { + return -1; + } else if (a_shift < b_shift) { + return -1; + } else if (a_shift > b_shift) { + return 1; + } else if (a_fraction < b_fraction) { + return -1; + } else if (a_fraction > b_fraction) { + return 1; + } else { + return 0; + } +} + void PreprocessSoftmaxScaling(double beta, double input_scale, int input_integer_bits, int32_t* quantized_multiplier, int* left_shift) { @@ -72,8 +263,20 @@ void PreprocessSoftmaxScaling(double beta, double input_scale, // result is double equivalent of Q0.31 (actually with more precision). Thus // this generates a Q(input_integer_bits).(31-input_integer_bits) // representation. 
+#ifdef TFLITE_EMULATE_FLOAT + const double input_beta = IntegerDoubleMultiply(beta, input_scale); + int shift; + int64_t fraction = IntegerFrExp(input_beta, &shift); + shift += (31 - input_integer_bits); + double input_beta_real_multiplier = + DoubleFromFractionAndShift(fraction, shift); + if (IntegerDoubleCompare(input_beta_real_multiplier, (1ll << 31) - 1.0) > 0) { + input_beta_real_multiplier = (1ll << 31) - 1.0; + } +#else // TFLITE_EMULATE_FLOAT const double input_beta_real_multiplier = std::min( beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0); +#endif // TFLITE_EMULATE_FLOAT QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier, quantized_multiplier, left_shift); @@ -97,6 +300,12 @@ void PreprocessLogSoftmaxScalingExp(double beta, double input_scale, } int CalculateInputRadius(int input_integer_bits, int input_left_shift) { +#ifdef TFLITE_EMULATE_FLOAT + int64_t result = (1 << input_integer_bits) - 1; + result <<= (31 - input_integer_bits); + result >>= input_left_shift; + return result; +#else // TFLITE_EMULATE_FLOAT const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) * (1ll << (31 - input_integer_bits)) / (1ll << input_left_shift); @@ -104,6 +313,7 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift) { // After scaling the difference, the result would be at the maximum. Thus we // must ensure that our value has lower magnitude. 
return static_cast<int>(std::floor(max_input_rescaled)); +#endif // TFLITE_EMULATE_FLOAT } void NudgeQuantizationRange(const float min, const float max, diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index 9ee4a47fbb..d74a1bac97 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -195,6 +195,44 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier, void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, int* shift); +// Splits a double input value into a returned fraction, and a shift value from +// the exponent, using only bitwise and integer operations to support +// microcontrollers and other environments without floating-point support. +// +// This is designed to be a replacement for how std::frexp() is used within the +// QuantizeMultiplier() function, and so has a different signature than the +// standard version, returning a 64-bit integer rather than a double. This +// result has a maximum value of 1<<31, with the fraction expressed as a +// proportion of that maximum. +// +// std::frexp() returns NaNs and infinities unmodified, but since we're +// returning integers that can't represent those values, instead we return +// a shift of std::numeric_limits<int>::max() for all bad numbers, with an int64 +// result of 0 for NaNs, std:numeric_limits<int64_t>::max() for +INFINITY, and +// std::numeric_limits<int64_t>::min() for -INFINITY. Denormalized inputs will +// result in return values that end up truncating some bits at the end, +// reflecting the loss of precision inherent in denormalization. +int64_t IntegerFrExp(double input, int* shift); + +// Converts an integer fraction in the format produced by IntegerFrExp (where +// 0x40000000 is 1.0) and an exponent shift (between -1022 and +1022) into an +// IEEE binary64 double format result. 
The implementation uses only integer and +// bitwise operators, so no floating point hardware support or emulation is +// needed. This is here so quantized operations can run non-time-critical +// preparation calculations on microcontrollers and other platforms without +// float support. +double DoubleFromFractionAndShift(int64_t fraction, int shift); + +// Performs a multiplication of two numbers in double format, using only integer +// and bitwise instructions. This is aimed at supporting housekeeping functions +// for quantized operations on microcontrollers without floating-point hardware. +double IntegerDoubleMultiply(double a, double b); + +// Returns -1 if a is less than b, 0 if a and b are equal, and +1 if a is +// greater than b. It is implemented using only integer and logical instructions +// so that it can be easily run on microcontrollers for quantized operations. +int IntegerDoubleCompare(double a, double b); + // This first creates a multiplier in a double equivalent of // Q(input_integer_bits).(31-input_integer_bits) representation, with extra // precision in the double's fractional bits. 
It then splits the result into diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc index 00fc3e91dc..14281f25c6 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -191,6 +191,139 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) { EXPECT_EQ(qp.zero_point, 255); } +TEST(QuantizationUtilTest, IntegerFrExp) { + int shift; + int64_t result = IntegerFrExp(0.0, &shift); + EXPECT_EQ(0, result); + EXPECT_EQ(0, shift); + + result = IntegerFrExp(1.0, &shift); + EXPECT_NEAR(0x40000000, result, 1); + EXPECT_EQ(1, shift); + + result = IntegerFrExp(0.25, &shift); + EXPECT_NEAR(0x40000000, result, 1); + EXPECT_EQ(-1, shift); + + result = IntegerFrExp(-1.0, &shift); + EXPECT_NEAR(-(1 << 30), result, 1); + EXPECT_EQ(1, shift); + + result = IntegerFrExp(123.45, &shift); + EXPECT_NEAR(2071147315, result, 1); + EXPECT_EQ(7, shift); + + result = IntegerFrExp(NAN, &shift); + EXPECT_NEAR(0, result, 1); + EXPECT_EQ(0x7fffffff, shift); + + result = IntegerFrExp(INFINITY, &shift); + EXPECT_NEAR(std::numeric_limits<int64_t>::max(), result, 1); + EXPECT_EQ(0x7fffffff, shift); + + result = IntegerFrExp(-INFINITY, &shift); + EXPECT_NEAR(std::numeric_limits<int64_t>::min(), result, 1); + EXPECT_EQ(0x7fffffff, shift); +} + +TEST(QuantizationUtilTest, IntegerFrExpVersusDouble) { + int shift; + int32_t result = IntegerFrExp(0.0, &shift); + EXPECT_EQ(result, 0); + EXPECT_EQ(shift, 0); + + int double_shift; + double double_result = std::frexp(0.0, &double_shift); + EXPECT_EQ(double_result, 0); + EXPECT_EQ(double_shift, 0); + + result = IntegerFrExp(1.0, &shift); + EXPECT_NEAR(result, 0x40000000, 1); + EXPECT_EQ(shift, 1); + double_result = std::frexp(1.0, &double_shift); + EXPECT_NEAR(double_result, 0.5, 1e-5); + EXPECT_EQ(double_shift, 1); + + result = IntegerFrExp(0.25, 
&shift); + EXPECT_NEAR(result, 0x40000000, 1); + EXPECT_EQ(shift, -1); + double_result = std::frexp(0.25, &double_shift); + EXPECT_NEAR(double_result, 0.5, 1e-5); + EXPECT_EQ(double_shift, -1); + + result = IntegerFrExp(-1.0, &shift); + EXPECT_NEAR(result, -(1 << 30), 1); + EXPECT_EQ(shift, 1); + double_result = std::frexp(-1.0, &double_shift); + EXPECT_NEAR(double_result, -0.5, 1e-5); + EXPECT_EQ(double_shift, 1); + + result = IntegerFrExp(123.45, &shift); + EXPECT_NEAR(result, (0.964453 * (1L << 31)), 1000); + EXPECT_EQ(shift, 7); + double_result = std::frexp(123.45, &double_shift); + EXPECT_NEAR(double_result, 0.964453, 1e-5); + EXPECT_EQ(double_shift, 7); +} + +TEST(QuantizationUtilTest, DoubleFromFractionAndShift) { + double result = DoubleFromFractionAndShift(0, 0); + EXPECT_EQ(0, result); + + result = DoubleFromFractionAndShift(0x40000000, 1); + EXPECT_NEAR(1.0, result, 1e-5); + + result = DoubleFromFractionAndShift(0x40000000, 2); + EXPECT_NEAR(2.0, result, 1e-5); + + int shift; + int64_t fraction = IntegerFrExp(3.0, &shift); + result = DoubleFromFractionAndShift(fraction, shift); + EXPECT_NEAR(3.0, result, 1e-5); + + fraction = IntegerFrExp(123.45, &shift); + result = DoubleFromFractionAndShift(fraction, shift); + EXPECT_NEAR(123.45, result, 1e-5); + + fraction = IntegerFrExp(-23.232323, &shift); + result = DoubleFromFractionAndShift(fraction, shift); + EXPECT_NEAR(-23.232323, result, 1e-5); + + fraction = IntegerFrExp(NAN, &shift); + result = DoubleFromFractionAndShift(fraction, shift); + EXPECT_TRUE(std::isnan(result)); + + fraction = IntegerFrExp(INFINITY, &shift); + result = DoubleFromFractionAndShift(fraction, shift); + EXPECT_FALSE(std::isfinite(result)); +} + +TEST(QuantizationUtilTest, IntegerDoubleMultiply) { + EXPECT_NEAR(1.0, IntegerDoubleMultiply(1.0, 1.0), 1e-5); + EXPECT_NEAR(2.0, IntegerDoubleMultiply(1.0, 2.0), 1e-5); + EXPECT_NEAR(2.0, IntegerDoubleMultiply(2.0, 1.0), 1e-5); + EXPECT_NEAR(4.0, IntegerDoubleMultiply(2.0, 2.0), 1e-5); + 
EXPECT_NEAR(0.5, IntegerDoubleMultiply(1.0, 0.5), 1e-5); + EXPECT_NEAR(0.25, IntegerDoubleMultiply(0.5, 0.5), 1e-5); + EXPECT_NEAR(-1.0, IntegerDoubleMultiply(1.0, -1.0), 1e-5); + EXPECT_NEAR(-1.0, IntegerDoubleMultiply(-1.0, 1.0), 1e-5); + EXPECT_NEAR(1.0, IntegerDoubleMultiply(-1.0, -1.0), 1e-5); + EXPECT_NEAR(15000000.0, IntegerDoubleMultiply(3000.0, 5000.0), 1e-5); + EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(NAN, 5000.0))); + EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(3000.0, NAN))); +} + +TEST(QuantizationUtilTest, IntegerDoubleCompare) { + EXPECT_EQ(-1, IntegerDoubleCompare(0.0, 1.0)); + EXPECT_EQ(1, IntegerDoubleCompare(1.0, 0.0)); + EXPECT_EQ(0, IntegerDoubleCompare(1.0, 1.0)); + EXPECT_EQ(0, IntegerDoubleCompare(0.0, 0.0)); + EXPECT_EQ(-1, IntegerDoubleCompare(-10.0, 10.0)); + EXPECT_EQ(1, IntegerDoubleCompare(123.45, 10.0)); + EXPECT_EQ(1, IntegerDoubleCompare(NAN, INFINITY)); + EXPECT_EQ(1, IntegerDoubleCompare(INFINITY, NAN)); +} + #ifdef GTEST_HAS_DEATH_TEST TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) { EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), ""); diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc index e79e75a898..2a30910c3f 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -173,6 +173,16 @@ void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, } } +void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch, + float* batch_vector) { + for (int b = 0; b < n_batch; b++) { + for (int i = 0; i < v_size; ++i) { + batch_vector[i] += vector[i]; + } + batch_vector += v_size; + } +} + void PortableVectorBatchVectorAssign(const float* vector, int v_size, int n_batch, float* batch_vector) { for (int b = 0; b < n_batch; b++) { @@ -243,5 +253,31 
@@ void PortableReductionSumVector(const float* input_vector, float* output_vector, } } +void PortableMeanStddevNormalization(const float* input_vector, + float* output_vector, int v_size, + int n_batch, float normalization_epsilon) { + for (int batch = 0; batch < n_batch; ++batch) { + float sum = 0.0f; + float sum_sq = 0.0f; + for (int i = 0; i < v_size; ++i) { + sum += input_vector[i]; + sum_sq += input_vector[i] * input_vector[i]; + } + const float mean = sum / v_size; + float stddev_inv = 0.0f; + const float variance = sum_sq / v_size - mean * mean; + if (variance == 0) { + stddev_inv = 1.0f / sqrt(normalization_epsilon); + } else { + stddev_inv = 1.0f / sqrt(variance); + } + for (int i = 0; i < v_size; ++i) { + output_vector[i] = (input_vector[i] - mean) * stddev_inv; + } + input_vector += v_size; + output_vector += v_size; + } +} + } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index 3829be0c5e..f5b3a84f07 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -87,6 +87,10 @@ void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, void PortableVectorBatchVectorAssign(const float* vector, int v_size, int n_batch, float* batch_vector); +// Add another vector for each batch in the batch vector. +void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch, + float* batch_vector); + // Apply sigmoid to elements of a vector. void PortableApplySigmoidToVector(const float* vector, int v_size, float* result); @@ -125,6 +129,12 @@ void PortableVectorShiftLeft(float* vector, int v_size, float shift_value); void PortableReductionSumVector(const float* input_vector, float* output_vector, int output_size, int reduction_size); +// Layer norm for each batch. 
+// normalization_epsilon is added to avoid divergence. +void PortableMeanStddevNormalization(const float* input_vector, + float* output_vector, int v_size, + int n_batch, float normalization_epsilon); + float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } bool IsZeroVector(const float* vector, int v_size) { @@ -193,6 +203,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1, result, result_stride); } +void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch, + float* batch_vector) { + PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector); +} + void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, float* batch_vector) { PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector); @@ -240,6 +255,13 @@ void ReductionSumVector(const float* input_vector, float* output_vector, reduction_size); } +void MeanStddevNormalization(const float* input_vector, float* output_vector, + int v_size, int n_batch, + float normalization_epsilon) { + PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch, + normalization_epsilon); +} + } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 62f7ade7d5..00f9616cc2 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -2524,32 +2524,69 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, } template <typename Scalar> +void Split(const SplitParams& params, const RuntimeShape& input_shape, + const Scalar* input_data, const RuntimeShape* const* output_shapes, + Scalar* const* output_data) { + const int concat_dimensions = input_shape.DimensionsCount(); + int axis = params.axis < 0 ? 
params.axis + concat_dimensions : params.axis; + int outputs_count = params.num_split; + TFLITE_DCHECK_LT(axis, concat_dimensions); + + int64_t concat_size = 0; + for (int i = 0; i < outputs_count; i++) { + TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), concat_dimensions); + for (int j = 0; j < concat_dimensions; j++) { + if (j != axis) { + MatchingDim(*output_shapes[i], j, input_shape, j); + } + } + concat_size += output_shapes[i]->Dims(axis); + } + TFLITE_DCHECK_EQ(concat_size, input_shape.Dims(axis)); + int64_t outer_size = 1; + for (int i = 0; i < axis; ++i) { + outer_size *= input_shape.Dims(i); + } + // For all output arrays, + // FlatSize() = outer_size * Dims(axis) * base_inner_size; + int64_t base_inner_size = 1; + for (int i = axis + 1; i < concat_dimensions; ++i) { + base_inner_size *= input_shape.Dims(i); + } + + const Scalar* input_ptr = input_data; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < outputs_count; ++i) { + const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size; + memcpy(output_data[i] + k * copy_size, input_ptr, + copy_size * sizeof(Scalar)); + input_ptr += copy_size; + } + } +} + +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy Dims<4>. 
+template <typename Scalar> void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, int axis, int outputs_count, Scalar* const* output_data, const Dims<4>* const* output_dims) { - const int batches = ArraySize(*output_dims[0], 3); - const int height = ArraySize(*output_dims[0], 2); - const int width = ArraySize(*output_dims[0], 1); - const int depth = ArraySize(*output_dims[0], 0); - - const int slice_size = ArraySize(*output_dims[0], axis); - + std::vector<RuntimeShape> output_shapes(outputs_count); + std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count); for (int i = 0; i < outputs_count; ++i) { - int offset = i * slice_size * input_dims.strides[axis]; - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - auto out = Offset(*output_dims[i], c, x, y, b); - auto in = Offset(input_dims, c, x, y, b); - output_data[i][out] = input_data[offset + in]; - } - } - } - } + ShapeFromDims(*output_dims[i], &output_shapes[i]); + output_shapes_indirect[i] = &output_shapes[i]; } + tflite::SplitParams op_params; + op_params.axis = 3 - axis; + op_params.num_split = outputs_count; + + Split(op_params, DimsToShape(input_dims), input_data, + output_shapes_indirect.data(), output_data); } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy Dims<4>. template <FusedActivationFunctionType Ac, typename Scalar> void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, int outputs_count, Scalar* const* output_data, @@ -2560,9 +2597,8 @@ void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); } - // for now we dont have a model with a TensorFlowSplit - // with fused activation function. 
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + // For now we don't have a model with a Split with fused activation. + TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone); TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count, output_data, output_dims); @@ -3416,23 +3452,55 @@ inline void Floor(const RuntimeShape& input_shape, const float* input_data, } template <typename T> -inline void Gather(const T* input_data, const Dims<4>& input_dims, - int input_rank, const int32* coords_data, - const Dims<4>& coords_dims, T* output_data, - const Dims<4>& output_dims) { - TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); - int stride = input_dims.strides[input_rank - 1]; +inline void Gather(const tflite::GatherParams& op_params, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& coords_shape, const int32* coords_data, + const RuntimeShape& output_shape, T* output_data) { + // TODO(b/80418076): Enable these checks when moving legacy ops to + // legacy_reference_ops. 
+ // + // TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1); + const int input_rank = op_params.input_rank; + const int gather_dimensions = output_shape.DimensionsCount(); + TFLITE_DCHECK_LE(input_shape.DimensionsCount(), gather_dimensions); + const int axis = gather_dimensions - input_rank; + TFLITE_DCHECK_LT(axis, gather_dimensions); + TFLITE_DCHECK_GE(axis, 0); + const int coords_count = coords_shape.FlatSize(); + TFLITE_DCHECK_EQ(coords_count, output_shape.Dims(axis)); + + int64_t stride = 1; + for (int i = axis + 1; i < gather_dimensions; ++i) { + stride *= input_shape.Dims(i); + } T* out = output_data; - for (int i = 0; i < coords_dims.sizes[0]; i++) { + for (int i = 0; i < coords_count; ++i) { TFLITE_DCHECK_GE(coords_data[i], 0); - TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); + TFLITE_DCHECK_LT(coords_data[i], input_shape.Dims(axis)); const T* in = input_data + coords_data[i] * stride; memcpy(out, in, sizeof(T) * stride); out += stride; } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy Dims<4> version. +// When moving legacy ops to legacy_reference_ops, replace content with looser +// implementation. 
+template <typename T> +inline void Gather(const T* input_data, const Dims<4>& input_dims, + int input_rank, const int32* coords_data, + const Dims<4>& coords_dims, T* output_data, + const Dims<4>& output_dims) { + tflite::GatherParams op_params; + op_params.input_rank = input_rank; + + Gather(op_params, DimsToShape(input_dims), input_data, + DimsToShape(coords_dims), coords_data, DimsToShape(output_dims), + output_data); +} + template <typename T> inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, const RuntimeShape& unextended_input_shape, @@ -4301,9 +4369,10 @@ template <typename T> using ComparisonFn = bool (*)(T, T); template <typename T, ComparisonFn<T> F> -inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data, - const RuntimeShape& input2_shape, const T* input2_data, - const RuntimeShape& output_shape, bool* output_data) { +inline void ComparisonImpl( + const ComparisonParams& op_params, const RuntimeShape& input1_shape, + const T* input1_data, const RuntimeShape& input2_shape, + const T* input2_data, const RuntimeShape& output_shape, bool* output_data) { const int64_t flatsize = MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int64_t i = 0; i < flatsize; ++i) { @@ -4311,25 +4380,45 @@ inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data, } } +template <ComparisonFn<float> F> +inline void Comparison(const ComparisonParams& op_params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, bool* output_data) { + ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape, + input2_data, output_shape, output_data); +} + +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. 
template <typename T, ComparisonFn<T> F> inline void Comparison(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, bool* output_data, const Dims<4>& output_dims) { - Comparison<T, F>(DimsToShape(input1_dims), input1_data, - DimsToShape(input2_dims), input2_data, - DimsToShape(output_dims), output_data); + ComparisonParams op_params; + // No parameters needed. + ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); } template <typename T, ComparisonFn<int32> F> -inline void Comparison(int left_shift, const T* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const T* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, bool* output_data, - const Dims<4>& output_dims) { +inline void ComparisonWithScaling( + const ComparisonParams& op_params, const RuntimeShape& input1_shape, + const T* input1_data, const RuntimeShape& input2_shape, + const T* input2_data, const RuntimeShape& output_shape, bool* output_data) { + int left_shift = op_params.left_shift; + int32 input1_offset = op_params.input1_offset; + int32 input1_multiplier = op_params.input1_multiplier; + int input1_shift = op_params.input1_shift; + int32 input2_offset = op_params.input2_offset; + int32 input2_multiplier = op_params.input2_multiplier; + int input2_shift = op_params.input2_shift; + const int64_t flatsize = - MatchingFlatSize(input1_dims, input2_dims, output_dims); + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int64_t i = 0; i < flatsize; ++i) { const int32 input1_val = input1_offset + input1_data[i]; const int32 input2_val = input2_offset + input2_data[i]; @@ -4337,68 +4426,140 @@ inline void Comparison(int left_shift, const T* input1_data, const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 
scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); + shifted_input1_val, input1_multiplier, input1_shift); const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); + shifted_input2_val, input2_multiplier, input2_shift); output_data[i] = F(scaled_input1_val, scaled_input2_val); } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +template <typename T, ComparisonFn<int32> F> +inline void Comparison(int left_shift, const T* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const T* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, bool* output_data, + const Dims<4>& output_dims) { + tflite::ComparisonParams op_params; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + + ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + template <typename T, ComparisonFn<T> F> -inline void BroadcastComparison(const T* input1_data, - const Dims<4>& input1_dims, - const T* input2_data, - const Dims<4>& input2_dims, bool* output_data, - const Dims<4>& output_dims) { +inline void BroadcastComparison4DSlowImpl( + const ComparisonParams& op_params, + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const T* input2_data, + const RuntimeShape& unextended_output_shape, bool* output_data) { + 
gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow"); + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - F(input1_data[SubscriptToIndex(desc1, c, x, y, b)], - input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, + unextended_input2_shape, &desc1, &desc2); + + for (int b = 0; b < output_shape.Dims(0); ++b) { + for (int y = 0; y < output_shape.Dims(1); ++y) { + for (int x = 0; x < output_shape.Dims(2); ++x) { + for (int c = 0; c < output_shape.Dims(3); ++c) { + output_data[Offset(output_shape, b, y, x, c)] = + F(input1_data[SubscriptToIndex(desc1, b, y, x, c)], + input2_data[SubscriptToIndex(desc2, b, y, x, c)]); } } } } } +template <ComparisonFn<float> F> +inline void BroadcastComparison4DSlow(const ComparisonParams& op_params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + bool* output_data) { + BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data, + input2_shape, input2_data, + output_shape, output_data); +} -template <typename T, ComparisonFn<int32> F> -inline void BroadcastComparison(int left_shift, const T* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, +// 
TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +template <typename T, ComparisonFn<T> F> +inline void BroadcastComparison(const T* input1_data, + const Dims<4>& input1_dims, const T* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 input2_multiplier, int input2_shift, - bool* output_data, const Dims<4>& output_dims) { + const Dims<4>& input2_dims, bool* output_data, + const Dims<4>& output_dims) { + ComparisonParams op_params; + // No parameters needed. + BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims), + input1_data, DimsToShape(input2_dims), + input2_data, DimsToShape(output_dims), + output_data); +} + +template <typename T, ComparisonFn<int32> F> +inline void BroadcastComparison4DSlowWithScaling( + const ComparisonParams& op_params, + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const T* input2_data, + const RuntimeShape& unextended_output_shape, bool* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling"); + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, + unextended_input2_shape, &desc1, &desc2); + + int left_shift = op_params.left_shift; + int32 input1_offset = op_params.input1_offset; + int32 input1_multiplier = 
op_params.input1_multiplier; + int input1_shift = op_params.input1_shift; + int32 input2_offset = op_params.input2_offset; + int32 input2_multiplier = op_params.input2_multiplier; + int input2_shift = op_params.input2_shift; + + for (int b = 0; b < output_shape.Dims(0); ++b) { + for (int y = 0; y < output_shape.Dims(1); ++y) { + for (int x = 0; x < output_shape.Dims(2); ++x) { + for (int c = 0; c < output_shape.Dims(3); ++c) { const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); + shifted_input1_val, input1_multiplier, input1_shift); const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); - output_data[Offset(output_dims, c, x, y, b)] = + shifted_input2_val, input2_multiplier, input2_shift); + output_data[Offset(output_shape, b, y, x, c)] = F(scaled_input1_val, scaled_input2_val); } } @@ -4406,51 +4567,117 @@ inline void BroadcastComparison(int left_shift, const T* input1_data, } } -#define TFLITE_COMPARISON_OP(name) \ - template <typename T> \ - inline void name(const T* input1_data, const Dims<4>& input1_dims, \ - const T* input2_data, const Dims<4>& input2_dims, \ - bool* output_data, const Dims<4>& output_dims) { \ - gemmlowp::ScopedProfilingLabel label(#name); \ - Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \ - input2_dims, output_data, output_dims); \ - } \ - template <typename T> \ - inline void name( 
\ - int left_shift, const T* input1_data, const Dims<4>& input1_dims, \ - int32 input1_offset, int32 input1_multiplier, int input1_shift, \ - const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \ - int32 input2_multiplier, int input2_shift, bool* output_data, \ - const Dims<4>& output_dims) { \ - gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \ - Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \ - input1_offset, input1_multiplier, input1_shift, \ - input2_data, input2_dims, input2_offset, \ - input2_multiplier, input2_shift, output_data, \ - output_dims); \ - } \ - template <typename T> \ - inline void Broadcast##name( \ - const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \ - const Dims<4>& input2_dims, bool* output_data, \ - const Dims<4>& output_dims) { \ - gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \ - BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \ - input2_dims, output_data, output_dims); \ - } \ - template <typename T> \ - inline void Broadcast##name( \ - int left_shift, const T* input1_data, const Dims<4>& input1_dims, \ - int32 input1_offset, int32 input1_multiplier, int input1_shift, \ - const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \ - int32 input2_multiplier, int input2_shift, bool* output_data, \ - const Dims<4>& output_dims) { \ - gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \ - BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \ - input1_offset, input1_multiplier, \ - input1_shift, input2_data, input2_dims, \ - input2_offset, input2_multiplier, \ - input2_shift, output_data, output_dims); \ +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. 
+template <typename T, ComparisonFn<int32> F> +inline void BroadcastComparison(int left_shift, const T* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const T* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 input2_multiplier, int input2_shift, + bool* output_data, const Dims<4>& output_dims) { + ComparisonParams op_params; + + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + + BroadcastComparison4DSlowWithScaling<T, F>( + op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +#define TFLITE_COMPARISON_OP(name) \ + template <typename T> \ + inline void name(const T* input1_data, const Dims<4>& input1_dims, \ + const T* input2_data, const Dims<4>& input2_dims, \ + bool* output_data, const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label(#name); \ + Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \ + input2_dims, output_data, output_dims); \ + } \ + template <typename T> \ + inline void name( \ + int left_shift, const T* input1_data, const Dims<4>& input1_dims, \ + int32 input1_offset, int32 input1_multiplier, int input1_shift, \ + const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \ + int32 input2_multiplier, int input2_shift, bool* output_data, \ + const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \ + Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \ + input1_offset, input1_multiplier, input1_shift, \ + input2_data, input2_dims, input2_offset, \ + input2_multiplier, input2_shift, output_data, \ + output_dims); \ 
+ } \ + template <typename T> \ + inline void Broadcast##name( \ + const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \ + const Dims<4>& input2_dims, bool* output_data, \ + const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \ + BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \ + input2_dims, output_data, output_dims); \ + } \ + template <typename T> \ + inline void Broadcast##name( \ + int left_shift, const T* input1_data, const Dims<4>& input1_dims, \ + int32 input1_offset, int32 input1_multiplier, int input1_shift, \ + const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \ + int32 input2_multiplier, int input2_shift, bool* output_data, \ + const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \ + BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \ + input1_offset, input1_multiplier, \ + input1_shift, input2_data, input2_dims, \ + input2_offset, input2_multiplier, \ + input2_shift, output_data, output_dims); \ + } \ + inline void name(const ComparisonParams& op_params, \ + const RuntimeShape& input1_shape, const float* input1_data, \ + const RuntimeShape& input2_shape, const float* input2_data, \ + const RuntimeShape& output_shape, bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label(#name); \ + Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape, \ + input2_data, output_shape, output_data); \ + } \ + template <typename T> \ + inline void name##WithScaling( \ + const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ + const T* input1_data, const RuntimeShape& input2_shape, \ + const T* input2_data, const RuntimeShape& output_shape, \ + bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \ + ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \ + input2_shape, input2_data, \ + output_shape, output_data); \ + } 
\ + inline void Broadcast4DSlow##name( \ + const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ + const float* input1_data, const RuntimeShape& input2_shape, \ + const float* input2_data, const RuntimeShape& output_shape, \ + bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \ + BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \ + input2_shape, input2_data, \ + output_shape, output_data); \ + } \ + template <typename T> \ + inline void Broadcast4DSlow##name##WithScaling( \ + const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ + const T* input1_data, const RuntimeShape& input2_shape, \ + const T* input2_data, const RuntimeShape& output_shape, \ + bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \ + BroadcastComparison4DSlowWithScaling<T, name##Fn>( \ + op_params, input1_shape, input1_data, input2_shape, input2_data, \ + output_shape, output_data); \ } TFLITE_COMPARISON_OP(Equal); TFLITE_COMPARISON_OP(NotEqual); diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index 748356d1bd..1439bf8c37 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -113,6 +113,10 @@ void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, const float* batch_vector, int n_batch, float* result); +// Add another vector for each batch in the batch vector. +void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch, + float* batch_vector); + // Batch vector initialization with another vector. void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, float* batch_vector); @@ -152,6 +156,12 @@ void VectorShiftLeft(float* vector, int v_size, float shift_value); // added to get one element of output. 
void ReductionSumVector(const float* input_vector, float* output_vector, int output_size, int reduction_size); + +// Layer norm for each batch. +// normalization_epsilon is added to avoid divergence. +void MeanStddevNormalization(const float* input_vector, float* output_vector, + int v_size, int n_batch, + float normalization_epsilon); } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc index 240fb64ca3..dad924fc28 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -496,6 +496,16 @@ TEST(uKernels, VectorVectorCwiseProductAccumulateTest) { {1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45}))); } +TEST(uKernels, VectorBatchVectorAddTest) { + constexpr int kVectorSize = 3; + constexpr int kBatchSize = 2; + static float input[kVectorSize] = {0.0, -0.5, 1.0}; + std::vector<float> output = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + VectorBatchVectorAdd(input, kVectorSize, kBatchSize, output.data()); + EXPECT_THAT(output, + testing::ElementsAreArray({1.0, 1.5, 4.0, 4.0, 4.5, 7.0})); +} + TEST(uKernels, VectorBatchVectorAssignTest) { constexpr int kVectorSize = 5; constexpr int kBatchSize = 3; @@ -712,5 +722,85 @@ TEST(uKernels, ReductionSumVectorTest) { EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5}))); } +TEST(uKernels, MeanStddevNormalizationNoneZeroInput) { + constexpr int kVectorSize = 4; + constexpr int kBatchSize = 2; + constexpr float kNormalizationEpsilon = 1e-8; + + // None-zero input. 
+ static float input[kVectorSize * kBatchSize] = { + 0.1, 0.2, 0.3, 0.4, // batch 0 + 0.9, 1.0, 1.1, 1.2, // batch 1 + }; + std::vector<float> output(kVectorSize * kBatchSize); + MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize, + kNormalizationEpsilon); + const std::vector<float> expected_output = { + -1.34164071, -0.447213531, 0.44721365, 1.34164071, // batch 0 + -1.34163153, -0.447210163, 0.447211236, 1.3416326, // batch 1 + }; + EXPECT_THAT(output, testing::ElementsAreArray(expected_output)); +} + +TEST(uKernels, MeanStddevNormalizationAllZeroInput) { + constexpr int kVectorSize = 4; + constexpr int kBatchSize = 2; + constexpr float kNormalizationEpsilon = 1e-8; + + // Zero input. + static float input[kVectorSize * kBatchSize] = { + 0.0, 0.0, 0.0, 0.0, // batch 0 + 0.0, 0.0, 0.0, 0.0, // batch 1 + }; + std::vector<float> output(kVectorSize * kBatchSize); + MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize, + kNormalizationEpsilon); + const std::vector<float> expected_output = { + 0.0, 0.0, 0.0, 0.0, // batch 0 + 0.0, 0.0, 0.0, 0.0, // batch 1 + }; + EXPECT_THAT(output, testing::ElementsAreArray(expected_output)); +} + +TEST(uKernels, MeanStddevNormalizationMixed) { + constexpr int kVectorSize = 4; + constexpr int kBatchSize = 2; + constexpr float kNormalizationEpsilon = 1e-8; + + // Mix of zero and non-zero input. 
+ static float input[kVectorSize * kBatchSize] = { + 0.0, 0.0, 0.0, 0.0, // batch 0 + 0.1, 0.2, 0.3, 0.4, // batch 1 + }; + std::vector<float> output(kVectorSize * kBatchSize); + MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize, + kNormalizationEpsilon); + const std::vector<float> expected_output = { + 0.0, 0.0, 0.0, 0.0, // batch 0 + -1.34164071, -0.447213531, 0.44721365, 1.34164071, // batch 1 + }; + EXPECT_THAT(output, testing::ElementsAreArray(expected_output)); +} + +TEST(uKernels, MeanStddevNormalizationSmallValue) { + constexpr int kVectorSize = 4; + constexpr int kBatchSize = 2; + constexpr float kNormalizationEpsilon = 1e-8; + + // Mix of zero and non-zero input. + static float input[kVectorSize * kBatchSize] = { + 3e-5, -7e-6, -9e-5, 1e-6, // batch 0 + 4e-5, 9e-6, 2e-4, 0.0, // batch 1 + }; + std::vector<float> output(kVectorSize * kBatchSize); + MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize, + kNormalizationEpsilon); + const std::vector<float> expected_output = { + 1.04231524, 0.212946132, -1.64753067, 0.392269224, // batch 0 + -0.275023013, -0.658201098, 1.70267045, -0.769446373, // batch 1 + }; + EXPECT_THAT(output, testing::ElementsAreArray(expected_output)); +} + } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 3b296f024f..9f6e74a267 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -720,12 +720,12 @@ struct ConcatenationParams { struct ComparisonParams { // uint8 inference params. int left_shift; - int32 input0_offset; - int32 input0_multiplier; - int input0_shift; int32 input1_offset; int32 input1_multiplier; int input1_shift; + int32 input2_offset; + int32 input2_multiplier; + int input2_shift; // Shape dependent / common to inference types. 
bool is_broadcast; }; @@ -889,6 +889,7 @@ struct SplitParams { // Graphs that split into, say, 2000 nodes are encountered. The indices in // OperatorEdges are of type uint16. uint16 num_split; + int16 axis; }; struct SqueezeParams { diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc new file mode 100644 index 0000000000..1bbea67b93 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc @@ -0,0 +1,1316 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Layer Normalization LSTM op that applies normalization by mean and standard +// deviation to the activation of the LSTM layers. Please see +// https://arxiv.org/abs/1607.06450 for details. +#include "flatbuffers/flexbuffers.h" // flatbuffers +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace layer_norm_lstm { + +// Struct to hold Layer Norm LSTM option data. 
+struct OpData { + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + int scratch_tensor_index; +}; + +// Input Tensors of size {n_batch, n_input} +constexpr int kInputTensor = 0; + +// Input weight tensors of size: {n_cell, n_input} +constexpr int kInputToInputWeightsTensor = 1; // Optional +constexpr int kInputToForgetWeightsTensor = 2; +constexpr int kInputToCellWeightsTensor = 3; +constexpr int kInputToOutputWeightsTensor = 4; + +// Recurrent weight tensors of size {n_cell, n_output} +constexpr int kRecurrentToInputWeightsTensor = 5; // Optional +constexpr int kRecurrentToForgetWeightsTensor = 6; +constexpr int kRecurrentToCellWeightsTensor = 7; +constexpr int kRecurrentToOutputWeightsTensor = 8; + +// Peephole weights tensors of size {n_cell}, representing a diagonal matrix. +constexpr int kCellToInputWeightsTensor = 9; // Optional +constexpr int kCellToForgetWeightsTensor = 10; // Optional +constexpr int kCellToOutputWeightsTensor = 11; // Optional + +// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix. +constexpr int kInputLayerNormWeightsTensor = 12; +constexpr int kForgetLayerNormWeightsTensor = 13; +constexpr int kCellLayerNormWeightsTensor = 14; +constexpr int kOutputLayerNormWeightsTensor = 15; + +// Gates bias tensors of size {n_cell} +constexpr int kInputGateBiasTensor = 16; // Optional +constexpr int kForgetGateBiasTensor = 17; +constexpr int kCellGateBiasTensor = 18; +constexpr int kOutputGateBiasTensor = 19; + +// Projection weight tensor of size {n_output, n_cell} +constexpr int kProjectionWeightsTensor = 20; // Optional +// Projection bias tensor of size {n_output} +constexpr int kProjectionBiasTensor = 21; // Optional + +// State tensors. +constexpr int kInputActivationStateTensor = 22; +constexpr int kInputCellStateTensor = 23; + +// Output tensor. +constexpr int kOutputTensor = 0; + +// Total number of scratch tensors for hybrid Op. 
+constexpr int kTensorsToAdd = 7; + +// Small float to avoid divergence during calculation of deviation. +const float kLayerNormEpsilon = 1e-8; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + + // Turn custom option data into flexbuffer map format. + const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + + // Get activation function, cell_clip and proj_clip from the flexbuffer. + // TODO(b/113824099): make activation more generic. + assert(m["fused_activation_function"].ToString() == "TANH"); + data->activation = kTfLiteActTanh; + data->cell_clip = m["cell_clip"].AsFloat(); + data->proj_clip = m["proj_clip"].AsFloat(); + + // Populate scratch_tensor_index. + context->AddTensors(context, /*tensors_to_add=*/kTensorsToAdd, + &data->scratch_tensor_index); + return data; +} + +// Check that input tensor dimensions matches with each other. +TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, + TfLiteNode* node, int n_input, + int n_output, int n_cell) { + const OpData* op_data = reinterpret_cast<OpData*>(node->user_data); + + // Making sure clipping parameters have valid values. 
+ // == 0 means no clipping + // > 0 means clipping + TF_LITE_ENSURE(context, op_data->cell_clip >= 0); + TF_LITE_ENSURE(context, op_data->proj_clip >= 0); + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + if (input_to_input_weights != nullptr) { + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); + } + + const TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); + + const TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); + + const TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + if (recurrent_to_input_weights != nullptr) { + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], + n_output); + } + + const TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], + n_output); + + const 
TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], + n_output); + + // We make sure the input-gate's parameters are either both present (regular + // LSTM) or not at all (CIFG-LSTM). + const bool cifg_weights_all_or_none = + ((input_to_input_weights != nullptr) && + (recurrent_to_input_weights != nullptr)) || + ((input_to_input_weights == nullptr) && + (recurrent_to_input_weights == nullptr)); + TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); + + const TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + if (cell_to_input_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); + } + + const TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + if (cell_to_forget_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); + } + + const TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + if (cell_to_output_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); + } + + // Making sure the peephole weights are there all or none. 
+ const bool use_cifg = (input_to_input_weights == nullptr); + const bool peephole_weights_all_or_none = + ((cell_to_input_weights != nullptr || use_cifg) && + (cell_to_forget_weights != nullptr) && + (cell_to_output_weights != nullptr)) || + ((cell_to_input_weights == nullptr) && + (cell_to_forget_weights == nullptr) && + (cell_to_output_weights == nullptr)); + TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); + + // Making sure layer norm weights are not null and have the right dimension. + const TfLiteTensor* input_layer_norm_weights = + GetInput(context, node, kInputLayerNormWeightsTensor); + TF_LITE_ENSURE(context, input_layer_norm_weights != nullptr); + TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->data[0], n_cell); + + const TfLiteTensor* forget_layer_norm_weights = + GetInput(context, node, kForgetLayerNormWeightsTensor); + TF_LITE_ENSURE(context, forget_layer_norm_weights != nullptr); + TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->data[0], n_cell); + + const TfLiteTensor* cell_layer_norm_weights = + GetInput(context, node, kCellLayerNormWeightsTensor); + TF_LITE_ENSURE(context, cell_layer_norm_weights != nullptr); + TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->data[0], n_cell); + + const TfLiteTensor* output_layer_norm_weights = + GetInput(context, node, kOutputLayerNormWeightsTensor); + TF_LITE_ENSURE(context, output_layer_norm_weights != nullptr); + TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->data[0], n_cell); + + // Make sure the input gate bias is present only when not a CIFG-LSTM. 
+ const TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + if (use_cifg) { + TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + } else { + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); + } + + const TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); + + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell); + + const TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); + + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + if (projection_weights != nullptr) { + TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); + } + + const TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + if (projection_bias != nullptr) { + TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); + } + + // Making sure the projection tensors are consistent: + // 1) If projection weight is not present, then projection bias should not be + // present. + // 2) If projection weight is present, then projection bias is optional. 
+ const bool projection_tensors_consistent = + ((projection_weights != nullptr) || (projection_bias == nullptr)); + TF_LITE_ENSURE(context, projection_tensors_consistent == true); + + return kTfLiteOk; +} + +// Resize the output, state tensors based on the sizes of the input tensors. +// Allocate a temporary scratch tensor. Also check that the sizes of the input +// tensors match each other. +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* op_data = reinterpret_cast<OpData*>(node->user_data); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 24); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + + // Inferring batch size, number of outputs and number of cells from the + // input tensors. + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE(context, input->dims->size > 1); + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + + const TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + const int n_cell = input_to_output_weights->dims->data[0]; + TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); + + const TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0], + n_cell); + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Check that input tensor dimensions matches with each other. + TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input, + n_output, n_cell)); + + // Get the pointer to output, activation_state and cell_state tensors. 
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + const TfLiteTensor* activation_state = + GetInput(context, node, kInputActivationStateTensor); + const TfLiteTensor* cell_state = + GetInput(context, node, kInputCellStateTensor); + + // Check the shape of input state tensors. + // These tensor may be 1D or 2D. It's fine as long as the total size is + // correct. + TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output); + TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); + // Resize the output tensors. + TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); + output_size->data[0] = n_batch; + output_size->data[1] = n_output; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + + // The weights are of consistent type, so it suffices to check one. + const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && + input->type == kTfLiteFloat32); + + TfLiteIntArrayFree(node->temporaries); + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(7); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } + node->temporaries->data[0] = op_data->scratch_tensor_index; + + // Create a scratch buffer tensor. 
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + scratch_buffer->type = input->type; + scratch_buffer->allocation_type = kTfLiteArenaRw; + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const bool use_cifg = (input_to_input_weights == nullptr); + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; + if (use_cifg) { + // Reserving space for Cell, Forget, Output gates + scratch_buffer_size->data[1] = n_cell * 3; + } else { + // Reserving space for Input, Cell, Forget, Output gates + scratch_buffer_size->data[1] = n_cell * 4; + } + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + + if (is_hybrid_op) { + // Allocate temporary tensors to store quantized values of input, + // activation_state and cell_state tensors. + node->temporaries->data[1] = op_data->scratch_tensor_index + 1; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[2] = op_data->scratch_tensor_index + 2; + TfLiteTensor* activation_state_quantized = + GetTemporary(context, node, /*index=*/2); + activation_state_quantized->type = kTfLiteUInt8; + activation_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(activation_state_quantized->dims, + activation_state->dims)) { + TfLiteIntArray* activation_state_quantized_size = + TfLiteIntArrayCopy(activation_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, activation_state_quantized, + activation_state_quantized_size)); + } + 
node->temporaries->data[3] = op_data->scratch_tensor_index + 3; + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + cell_state_quantized->type = kTfLiteUInt8; + cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { + TfLiteIntArray* cell_state_quantized_size = + TfLiteIntArrayCopy(cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state_quantized, + cell_state_quantized_size)); + } + + // Allocate temporary tensors to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). + node->temporaries->data[4] = op_data->scratch_tensor_index + 4; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + node->temporaries->data[5] = op_data->scratch_tensor_index + 5; + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + prod_scaling_factors->type = kTfLiteFloat32; + prod_scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); + prod_scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(prod_scaling_factors->dims, + prod_scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, prod_scaling_factors, + prod_scaling_factors_size)); + } + + // 
Allocate a temporary tensor to store the recovered weights. Since + // this is used for diagonal matrices, only need to store n_cell values. + node->temporaries->data[6] = op_data->scratch_tensor_index + 6; + TfLiteTensor* recovered_weights = GetTemporary(context, node, /*index=*/6); + recovered_weights->type = kTfLiteFloat32; + recovered_weights->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* recovered_weights_size = TfLiteIntArrayCreate(1); + recovered_weights_size->data[0] = n_cell; + if (!TfLiteIntArrayEqual(recovered_weights->dims, recovered_weights_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, recovered_weights, + recovered_weights_size)); + } + } + return kTfLiteOk; +} + +void LayerNormLstmStep( + const float* input_ptr_batch, const float* input_to_input_weights_ptr, + const float* input_to_forget_weights_ptr, + const float* input_to_cell_weights_ptr, + const float* input_to_output_weights_ptr, + const float* recurrent_to_input_weights_ptr, + const float* recurrent_to_forget_weights_ptr, + const float* recurrent_to_cell_weights_ptr, + const float* recurrent_to_output_weights_ptr, + const float* cell_to_input_weights_ptr, + const float* cell_to_forget_weights_ptr, + const float* cell_to_output_weights_ptr, + const float* input_layer_norm_weight_ptr, + const float* forget_layer_norm_weight_ptr, + const float* cell_layer_norm_weight_ptr, + const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const float* projection_weights_ptr, + const float* projection_bias_ptr, float cell_clip, float proj_clip, + const TfLiteFusedActivation& activation, int n_batch, int n_cell, + int n_input, int n_output, float* output_state_ptr, float* cell_state_ptr, + float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, + float* output_gate_scratch, float* output_ptr_batch) { + // Since we have already checked 
that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + + // Initialize scratch buffers with 0. + if (!use_cifg) { + tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch); + } + tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch); + tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch); + tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch); + + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: compute recurrent_weight * output_state. 
+ if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, forget_gate_scratch, + /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, output_gate_scratch, + /*result_stride=*/1); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::MeanStddevNormalization(input_gate_scratch, + input_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr, + n_cell, input_gate_scratch, + n_batch, input_gate_scratch); + tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. 
+ if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::MeanStddevNormalization(forget_gate_scratch, + forget_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr, + n_cell, forget_gate_scratch, + n_batch, forget_gate_scratch); + tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, + n_batch, kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct( + cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch); + tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, + n_batch * n_cell, cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip, + cell_state_ptr); + } + + // For each batch and cell: update the output gate. 
+ if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); + } + tensor_utils::MeanStddevNormalization(output_gate_scratch, + output_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr, + n_cell, output_gate_scratch, + n_batch, output_gate_scratch); + tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_batch); + } else { + tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch, + output_ptr_batch, /*result_stride=*/1); + if (proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip, + output_ptr_batch); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_batch); + } + tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, + output_state_ptr); +} + +void LayerNormLstmStep( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float 
input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, + const float* input_layer_norm_weight_ptr, + const float* forget_layer_norm_weight_ptr, + const float* cell_layer_norm_weight_ptr, + const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + float cell_clip, float proj_clip, const TfLiteFusedActivation& activation, + int n_batch, int n_cell, int n_input, int n_output, + float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, + float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch) { + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. 
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + + // Initialize scratch buffers with 0. + if (!use_cifg) { + tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch); + } + tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch); + tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch); + tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch); + + if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_input; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset, + &unused_min, &unused_max, &scaling_factors[b]); + } + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, + quantized_input_ptr_batch, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, forget_gate_scratch, + /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + 
product_scaling_factors[b] = + scaling_factors[b] * input_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, output_gate_scratch, + /*result_stride=*/1); + } + + if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_output; + tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, + &unused_min, &unused_max, + &scaling_factors[b]); + } + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + forget_gate_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_output_weights_scale; + } + 
tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + output_gate_scratch, /*result_stride=*/1); + } + + // Save quantization and matmul computation for all zero input. + bool is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole && !is_cell_state_all_zeros) { + tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell, + cell_to_input_weights_scale, + recovered_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_weights, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::MeanStddevNormalization(input_gate_scratch, + input_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr, + n_cell, input_gate_scratch, + n_batch, input_gate_scratch); + tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. 
+ if (use_peephole && !is_cell_state_all_zeros) { + tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell, + cell_to_forget_weights_scale, + recovered_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_weights, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::MeanStddevNormalization(forget_gate_scratch, + forget_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr, + n_cell, forget_gate_scratch, + n_batch, forget_gate_scratch); + tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, + n_batch, kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct( + cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch); + tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, + n_batch * n_cell, cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip, + cell_state_ptr); + } + + is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + // For each batch and cell: update the output gate. 
+ if (use_peephole && !is_cell_state_all_zeros) { + tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell, + cell_to_output_weights_scale, + recovered_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_weights, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); + } + tensor_utils::MeanStddevNormalization(output_gate_scratch, + output_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr, + n_cell, output_gate_scratch, + n_batch, output_gate_scratch); + tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_batch); + } else { + tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); + } + if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) { + // Save quantization and matmul computation for all zero input. 
+      float unused_min, unused_max;
+      for (int b = 0; b < n_batch; ++b) {
+        const int offset = b * n_cell;
+        tensor_utils::SymmetricQuantizeFloats(
+            output_gate_scratch + offset, n_cell,
+            quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+            &scaling_factors[b]);
+      }
+      for (int b = 0; b < n_batch; ++b) {
+        product_scaling_factors[b] =
+            scaling_factors[b] * projection_weights_scale;
+      }
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
+          product_scaling_factors, n_batch, output_ptr_batch,
+          /*result_stride=*/1);
+    }
+    if (proj_clip > 0.0) {
+      tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
+                               output_ptr_batch);
+    }
+  } else {
+    tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+                             output_ptr_batch);
+  }
+  tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+                           output_state_ptr);
+}
+
+// The LayerNormLSTM Op engine.
+//
+// Float path. Reads the problem sizes from the tensor shapes, splits
+// `scratch_buffer` into per-gate slices, extracts the raw float pointer of
+// every tensor and forwards everything to a single LayerNormLstmStep call.
+//
+// Tensors whose pointer may be null (see the checks below):
+//   input_to_input_weights, recurrent_to_input_weights, input_gate_bias,
+//   cell_to_{input,forget,output}_weights, projection_weights,
+//   projection_bias.
+// All other tensor arguments are dereferenced unconditionally.
+//
+// `activation_state`, `cell_state` and `output` are passed to the step as
+// non-const buffers. Always returns kTfLiteOk.
+TfLiteStatus EvalFloat(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights,
+    const TfLiteTensor* input_layer_norm_weights,
+    const TfLiteTensor* forget_layer_norm_weights,
+    const TfLiteTensor* cell_layer_norm_weights,
+    const TfLiteTensor* output_layer_norm_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
+  // Sizes come straight from the tensor shapes: dims 0 and 1 of `input`,
+  // dim 0 of input_to_output_weights, dim 1 of recurrent_to_output_weights.
+  const int n_batch = input->dims->data[0];
+  const int n_input = input->dims->data[1];
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  // Carve the shared scratch buffer into consecutive slices of
+  // n_cell * n_batch floats: three slices with CIFG (the input gate slice is
+  // left null), four slices otherwise.
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  const float* input_to_input_weights_ptr =
+      (use_cifg) ? nullptr : input_to_input_weights->data.f;
+  const float* recurrent_to_input_weights_ptr =
+      (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+  const float* input_gate_bias_ptr =
+      (use_cifg) ? nullptr : input_gate_bias->data.f;
+  const float* cell_to_input_weights_ptr =
+      (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+  const float* cell_to_forget_weights_ptr =
+      (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+  const float* cell_to_output_weights_ptr =
+      (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+  const float* projection_weights_ptr =
+      (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Required tensors, pointers are non-null.
+  const float* input_ptr_batch = input->data.f;
+  const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
+  const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
+  const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
+  const float* recurrent_to_forget_weights_ptr =
+      recurrent_to_forget_weights->data.f;
+  const float* recurrent_to_cell_weights_ptr =
+      recurrent_to_cell_weights->data.f;
+  const float* recurrent_to_output_weights_ptr =
+      recurrent_to_output_weights->data.f;
+  const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f;
+  const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
+  const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
+  const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
+  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+  const float* cell_bias_ptr = cell_bias->data.f;
+  const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+  float* activation_state_ptr = activation_state->data.f;
+  float* cell_state_ptr = cell_state->data.f;
+  float* output_ptr_batch = output->data.f;
+
+  // NOTE(review): the arguments are positional; their order must match the
+  // float overload of LayerNormLstmStep, which is declared outside this
+  // excerpt.
+  LayerNormLstmStep(
+      input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+      input_to_cell_weights_ptr, input_to_output_weights_ptr,
+      recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+      recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+      cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
+      cell_to_output_weights_ptr, input_layer_norm_weight_ptr,
+      forget_layer_norm_weight_ptr, cell_layer_norm_weight_ptr,
+      output_layer_norm_weight_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
+      cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+      projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
+      n_input, n_output, activation_state_ptr, cell_state_ptr,
+      input_gate_scratch, forget_gate_scratch, cell_scratch,
+      output_gate_scratch, output_ptr_batch);
+
+  return kTfLiteOk;
+}
+
+// Hybrid path. Same role as EvalFloat, but the weight tensors are read
+// through `data.uint8` reinterpreted as int8_t, and each weight tensor's
+// `params.scale` is forwarded next to its pointer. The layer-norm weights,
+// biases, input, states and output are still read as float.
+//
+// Absent optional tensors yield a null pointer and a scale of 1.0f.
+//
+// Extra temporaries forwarded to LayerNormLstmStep:
+//   scaling_factors, prod_scaling_factors, recovered_weights : float buffers
+//   input_quantized, activation_state_quantized,
+//   cell_state_quantized                                     : int8 buffers
+//
+// Always returns kTfLiteOk.
+TfLiteStatus EvalHybrid(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights,
+    const TfLiteTensor* input_layer_norm_weights,
+    const TfLiteTensor* forget_layer_norm_weights,
+    const TfLiteTensor* cell_layer_norm_weights,
+    const TfLiteTensor* output_layer_norm_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_weights,
+    TfLiteTensor* input_quantized, TfLiteTensor* activation_state_quantized,
+    TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
+  const int n_batch = input->dims->data[0];
+  const int n_input = input->dims->data[1];
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  // Same scratch layout as the float path: consecutive slices of
+  // n_cell * n_batch floats, without an input gate slice under CIFG.
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  int8_t* input_to_input_weights_ptr = nullptr;
+  float input_to_input_weights_scale = 1.0f;
+  int8_t* recurrent_to_input_weights_ptr = nullptr;
+  float recurrent_to_input_weights_scale = 1.0f;
+  float* input_gate_bias_ptr = nullptr;
+  if (!use_cifg) {
+    input_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+    recurrent_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+    input_gate_bias_ptr = input_gate_bias->data.f;
+    input_to_input_weights_scale = input_to_input_weights->params.scale;
+    recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+  }
+
+  // Peephole weights: cell_to_input is only read when the input gate exists
+  // (no CIFG); cell_to_forget and cell_to_output are read whenever peephole
+  // connections are enabled.
+  int8_t* cell_to_input_weights_ptr = nullptr;
+  int8_t* cell_to_forget_weights_ptr = nullptr;
+  int8_t* cell_to_output_weights_ptr = nullptr;
+  float cell_to_input_weights_scale = 1.0f;
+  float cell_to_forget_weights_scale = 1.0f;
+  float cell_to_output_weights_scale = 1.0f;
+  if (use_peephole) {
+    if (!use_cifg) {
+      cell_to_input_weights_ptr =
+          reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+      cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+    }
+    cell_to_forget_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+    cell_to_output_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+    cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+    cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+  }
+
+  const int8_t* projection_weights_ptr =
+      (projection_weights == nullptr)
+          ? nullptr
+          : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+  const float projection_weights_scale =
+      (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Required tensors, pointers are non-null.
+  const float* input_ptr_batch = input->data.f;
+  const int8_t* input_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+  const float input_to_forget_weights_scale =
+      input_to_forget_weights->params.scale;
+  const int8_t* input_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+  const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+  const int8_t* input_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+  const float input_to_output_weights_scale =
+      input_to_output_weights->params.scale;
+  const int8_t* recurrent_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+  const float recurrent_to_forget_weights_scale =
+      recurrent_to_forget_weights->params.scale;
+  const int8_t* recurrent_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+  const float recurrent_to_cell_weights_scale =
+      recurrent_to_cell_weights->params.scale;
+  const int8_t* recurrent_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+  const float recurrent_to_output_weights_scale =
+      recurrent_to_output_weights->params.scale;
+  const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f;
+  const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
+  const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
+  const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
+  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+  const float* cell_bias_ptr = cell_bias->data.f;
+  const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+  float* activation_state_ptr = activation_state->data.f;
+  float* cell_state_ptr = cell_state->data.f;
+  float* output_ptr_batch = output->data.f;
+
+  // Temporary storage for quantized values and scaling factors.
+  int8_t* quantized_input_ptr =
+      reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+  int8_t* quantized_activation_state_ptr =
+      reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
+  int8_t* quantized_cell_state_ptr =
+      reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+  float* scaling_factors_ptr = scaling_factors->data.f;
+  float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+  float* recovered_weights_ptr = recovered_weights->data.f;
+
+  // NOTE(review): positional call into the hybrid overload of
+  // LayerNormLstmStep; every weight pointer is immediately followed by its
+  // scale. The overload's full parameter list is outside this excerpt.
+  LayerNormLstmStep(
+      input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
+      input_to_forget_weights_ptr, input_to_forget_weights_scale,
+      input_to_cell_weights_ptr, input_to_cell_weights_scale,
+      input_to_output_weights_ptr, input_to_output_weights_scale,
+      recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+      recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+      recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+      recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+      cell_to_input_weights_ptr, cell_to_input_weights_scale,
+      cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+      cell_to_output_weights_ptr, cell_to_output_weights_scale,
+      input_layer_norm_weight_ptr, forget_layer_norm_weight_ptr,
+      cell_layer_norm_weight_ptr, output_layer_norm_weight_ptr,
+      input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+      output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
+      projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
+      n_input, n_output, input_gate_scratch, forget_gate_scratch, cell_scratch,
+      output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+      recovered_weights_ptr, quantized_input_ptr,
+      quantized_activation_state_ptr, quantized_cell_state_ptr,
+      activation_state_ptr, cell_state_ptr, output_ptr_batch);
+
+  return kTfLiteOk;
+}
+
+// Kernel entry point. Looks up every input, state, temporary and output
+// tensor of `node`, then dispatches on the type of input_to_output_weights:
+//   kTfLiteFloat32 -> EvalFloat  (uses temporary 0 only)
+//   kTfLiteUInt8   -> EvalHybrid (additionally uses temporaries 1..6)
+// Any other type is reported through context->ReportError and yields
+// kTfLiteError.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  // cell_clip, proj_clip and activation are read from the node's user data.
+  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+  const TfLiteTensor* input_to_input_weights =
+      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+  const TfLiteTensor* input_to_forget_weights =
+      GetInput(context, node, kInputToForgetWeightsTensor);
+  const TfLiteTensor* input_to_cell_weights =
+      GetInput(context, node, kInputToCellWeightsTensor);
+  const TfLiteTensor* input_to_output_weights =
+      GetInput(context, node, kInputToOutputWeightsTensor);
+
+  const TfLiteTensor* recurrent_to_input_weights =
+      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+  const TfLiteTensor* recurrent_to_forget_weights =
+      GetInput(context, node, kRecurrentToForgetWeightsTensor);
+  const TfLiteTensor* recurrent_to_cell_weights =
+      GetInput(context, node, kRecurrentToCellWeightsTensor);
+  const TfLiteTensor* recurrent_to_output_weights =
+      GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+  const TfLiteTensor* cell_to_input_weights =
+      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+  const TfLiteTensor* cell_to_forget_weights =
+      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+  const TfLiteTensor* cell_to_output_weights =
+      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+  const TfLiteTensor* input_layer_norm_weights =
+      GetInput(context, node, kInputLayerNormWeightsTensor);
+  const TfLiteTensor* forget_layer_norm_weights =
+      GetInput(context, node, kForgetLayerNormWeightsTensor);
+  const TfLiteTensor* cell_layer_norm_weights =
+      GetInput(context, node, kCellLayerNormWeightsTensor);
+  const TfLiteTensor* output_layer_norm_weights =
+      GetInput(context, node, kOutputLayerNormWeightsTensor);
+
+  const TfLiteTensor* input_gate_bias =
+      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+  const TfLiteTensor* forget_gate_bias =
+      GetInput(context, node, kForgetGateBiasTensor);
+  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+  const TfLiteTensor* output_gate_bias =
+      GetInput(context, node, kOutputGateBiasTensor);
+
+  const TfLiteTensor* projection_weights =
+      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+  const TfLiteTensor* projection_bias =
+      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+  // Index the scratch buffers pointers to the global scratch buffer.
+  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+  // The two state tensors are node inputs, but are fetched as mutable
+  // tensors because the evaluation functions take them as non-const.
+  TfLiteTensor* activation_state =
+      &context->tensors[node->inputs->data[kInputActivationStateTensor]];
+  TfLiteTensor* cell_state =
+      &context->tensors[node->inputs->data[kInputCellStateTensor]];
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  switch (input_to_output_weights->type) {
+    case kTfLiteFloat32: {
+      return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
+                       input_to_cell_weights, input_to_output_weights,
+                       recurrent_to_input_weights, recurrent_to_forget_weights,
+                       recurrent_to_cell_weights, recurrent_to_output_weights,
+                       cell_to_input_weights, cell_to_forget_weights,
+                       cell_to_output_weights, input_layer_norm_weights,
+                       forget_layer_norm_weights, cell_layer_norm_weights,
+                       output_layer_norm_weights, input_gate_bias,
+                       forget_gate_bias, cell_bias, output_gate_bias,
+                       projection_weights, projection_bias, op_data->cell_clip,
+                       op_data->proj_clip, op_data->activation, scratch_buffer,
+                       activation_state, cell_state, output);
+    }
+    case kTfLiteUInt8: {
+      TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+      TfLiteTensor* activation_state_quantized =
+          GetTemporary(context, node, /*index=*/2);
+      TfLiteTensor* cell_state_quantized =
+          GetTemporary(context, node, /*index=*/3);
+      TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+      TfLiteTensor* prod_scaling_factors =
+          GetTemporary(context, node, /*index=*/5);
+      TfLiteTensor* recovered_weights =
+          GetTemporary(context, node, /*index=*/6);
+      return EvalHybrid(
+          input, input_to_input_weights, input_to_forget_weights,
+          input_to_cell_weights, input_to_output_weights,
+          recurrent_to_input_weights, recurrent_to_forget_weights,
+          recurrent_to_cell_weights, recurrent_to_output_weights,
+          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+          input_layer_norm_weights, forget_layer_norm_weights,
+          cell_layer_norm_weights, output_layer_norm_weights, input_gate_bias,
+          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+          projection_bias, op_data->cell_clip, op_data->proj_clip,
+          op_data->activation, scratch_buffer, scaling_factors,
+          prod_scaling_factors, recovered_weights, input_quantized,
+          activation_state_quantized, cell_state_quantized, activation_state,
+          cell_state, output);
+    }
+    default:
+      context->ReportError(context, "Type %d is not currently supported.",
+                           input_to_output_weights->type);
+      return kTfLiteError;
+  }
+  // Not reached: every switch branch above returns.
+  return kTfLiteOk;
+}
+
+// Destroys the per-node OpData passed in as `buffer`.
+// NOTE(review): presumably this is the object allocated by Init, which is
+// outside this excerpt. `context` is unused.
+void Free(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<OpData*>(buffer);
+}
+
+}  // namespace layer_norm_lstm
+
+// Returns the registration for the custom LAYER_NORM_LSTM op. The record is
+// a function-local static wired to Init / Free / Prepare / Eval of the
+// layer_norm_lstm namespace, so every call returns the same pointer.
+TfLiteRegistration* Register_LAYER_NORM_LSTM() {
+  static TfLiteRegistration r = {layer_norm_lstm::Init, layer_norm_lstm::Free,
+                                 layer_norm_lstm::Prepare,
+                                 layer_norm_lstm::Eval};
+  return &r;
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
new file mode 100644
index 0000000000..abc229f85a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
@@ -0,0 +1,664 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Layer Norm LSTM op. + +#include <memory> +#include <vector> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "flatbuffers/flexbuffers.h" // flatbuffers +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_LAYER_NORM_LSTM(); + +namespace { + +using ::testing::ElementsAreArray; + +class LayerNormLSTMOpModel : public SingleOpModel { + public: + LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + bool use_cifg, bool use_peephole, + bool use_projection_weights, bool use_projection_bias, + float cell_clip, float proj_clip, + const std::vector<std::vector<int>>& input_shapes, + const TensorType& weight_type = TensorType_FLOAT32) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(weight_type); + } + + input_to_forget_weights_ = AddInput(weight_type); + input_to_cell_weights_ = AddInput(weight_type); + input_to_output_weights_ = AddInput(weight_type); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(weight_type); + } + + 
recurrent_to_forget_weights_ = AddInput(weight_type); + recurrent_to_cell_weights_ = AddInput(weight_type); + recurrent_to_output_weights_ = AddInput(weight_type); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(weight_type); + } + cell_to_forget_weights_ = AddInput(weight_type); + cell_to_output_weights_ = AddInput(weight_type); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + input_layer_norm_weights_ = AddInput(TensorType_FLOAT32); + forget_layer_norm_weights_ = AddInput(TensorType_FLOAT32); + cell_layer_norm_weights_ = AddInput(TensorType_FLOAT32); + output_layer_norm_weights_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(weight_type); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + // Adding the 2 state tensors. + output_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); + cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); + + output_ = AddOutput(TensorType_FLOAT32); + + // Set up and pass in custom options using flexbuffer. 
+ flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("cell_clip", cell_clip); + fbb.Int("proj_clip", proj_clip); + fbb.String("fused_activation_function", "TANH"); + }); + fbb.Finish(); + SetCustomOp("LAYER_NORM_LSTM", fbb.GetBuffer(), Register_LAYER_NORM_LSTM); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list<float> f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list<float> f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list<float> f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list<float> f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list<float> f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list<float> f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list<float> f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list<float> f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list<float> f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list<float> f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list<float> f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputLayerNormWeights(std::initializer_list<float> f) { + PopulateTensor(input_layer_norm_weights_, f); + } + + void SetForgetLayerNormWeights(std::initializer_list<float> f) { + PopulateTensor(forget_layer_norm_weights_, f); + } + + void SetCellLayerNormWeights(std::initializer_list<float> f) { + PopulateTensor(cell_layer_norm_weights_, f); + } + + void 
SetOutputLayerNormWeights(std::initializer_list<float> f) { + PopulateTensor(output_layer_norm_weights_, f); + } + + void SetInputGateBias(std::initializer_list<float> f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list<float> f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list<float> f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list<float> f) { + PopulateTensor(output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list<float> f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list<float> f) { + PopulateTensor(projection_bias_, f); + } + + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast<float*>(begin), + const_cast<float*>(end)); + } + + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + protected: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_layer_norm_weights_; + int forget_layer_norm_weights_; + int cell_layer_norm_weights_; + int output_layer_norm_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + + int output_state_; + int cell_state_; + + int output_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + +class HybridLayerNormLSTMOpModel : public 
LayerNormLSTMOpModel { + public: + HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + bool use_cifg, bool use_peephole, + bool use_projection_weights, + bool use_projection_bias, float cell_clip, + float proj_clip, + const std::vector<std::vector<int>>& input_shapes) + : LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, + use_peephole, use_projection_weights, + use_projection_bias, cell_clip, proj_clip, + input_shapes, TensorType_UINT8) {} + + void SetInputToInputWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(cell_to_output_weights_, f); + } + + void 
SetInputLayerNormWeights(std::initializer_list<float> f) { + PopulateTensor(input_layer_norm_weights_, f); + } + + void SetForgetLayerNormWeights(std::initializer_list<float> f) { + PopulateTensor(forget_layer_norm_weights_, f); + } + + void SetCellLayerNormWeights(std::initializer_list<float> f) { + PopulateTensor(cell_layer_norm_weights_, f); + } + + void SetOutputLayerNormWeights(std::initializer_list<float> f) { + PopulateTensor(output_layer_norm_weights_, f); + } + + void SetProjectionWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(projection_weights_, f); + } +}; + +class BaseLayerNormLstmTest : public ::testing::Test { + protected: + // Weights of the Layer Norm LSTM model. Some are optional. + std::initializer_list<float> input_to_input_weights_; + std::initializer_list<float> input_to_cell_weights_; + std::initializer_list<float> input_to_forget_weights_; + std::initializer_list<float> input_to_output_weights_; + std::initializer_list<float> input_gate_bias_; + std::initializer_list<float> cell_gate_bias_; + std::initializer_list<float> forget_gate_bias_; + std::initializer_list<float> output_gate_bias_; + std::initializer_list<float> recurrent_to_input_weights_; + std::initializer_list<float> recurrent_to_cell_weights_; + std::initializer_list<float> recurrent_to_forget_weights_; + std::initializer_list<float> recurrent_to_output_weights_; + std::initializer_list<float> cell_to_input_weights_; + std::initializer_list<float> cell_to_forget_weights_; + std::initializer_list<float> cell_to_output_weights_; + std::initializer_list<float> input_layer_norm_weights_; + std::initializer_list<float> forget_layer_norm_weights_; + std::initializer_list<float> cell_layer_norm_weights_; + std::initializer_list<float> output_layer_norm_weights_; + std::initializer_list<float> projection_weights_; + + // Layer Norm LSTM input is stored as num_batch x num_inputs vector. 
+ std::vector<std::vector<float>> layer_norm_lstm_input_; + + // Compares output up to tolerance to the result of the layer_norm_lstm given + // the input. + void VerifyGoldens(const std::vector<std::vector<float>>& input, + const std::vector<std::vector<float>>& output, + LayerNormLSTMOpModel* layer_norm_lstm, + float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = layer_norm_lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(), + batch_start, batch_end); + } + + layer_norm_lstm->Invoke(); + + const int num_outputs = layer_norm_lstm->num_outputs(); + std::vector<float> expected; + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + EXPECT_THAT(layer_norm_lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest + : public BaseLayerNormLstmTest { + void SetUp() override { + input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, + 0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5, + -0.4, -0.5, -0.4, -0.3, -0.2, -0.1}; + + input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, + -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4, + -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; + + input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, + -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3, + -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; + + input_to_output_weights_ = {-0.8, 
-0.4, -0.2, -0.9, -0.1, -0.7, 0.3, + -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7, + -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; + + input_gate_bias_ = {0.03, 0.15, 0.22, 0.38}; + + forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1}; + + cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08}; + + output_gate_bias_ = {0.05, -0.01, 0.2, 0.1}; + + recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, + -0.2, -0.3, -0.7, 0.05, -0.2, -0.6}; + + recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, + -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; + + recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, + 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; + + recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, + -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; + + cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15}; + + cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03}; + + cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05}; + + input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5}; + forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3}; + cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8}; + output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5}; + + projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, + 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; + + layer_norm_lstm_input_ = { + {// Batch0: 3 (input_sequence_size) * 5 (n_input) + 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0 + 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1 + 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2 + + {// Batch1: 3 (input_sequence_size) * 5 (n_input) + 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0 + 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1 + 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2 + }; + } +}; + +TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, + LayerNormLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const float ceil_clip = 0.0; + const float proj_clip = 0.0; + + LayerNormLSTMOpModel layer_norm_lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + 
/*use_projection_bias=*/false, ceil_clip, proj_clip, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_layer_norm_weight tensor + {n_cell}, // forget_layer_norm_weight tensor + {n_cell}, // cell_layer_norm_weight tensor + {n_cell}, // output_layer_norm_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_); + layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); + layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); + layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); + + layer_norm_lstm.SetInputGateBias(input_gate_bias_); + layer_norm_lstm.SetCellBias(cell_gate_bias_); + layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); + layer_norm_lstm.SetOutputGateBias(output_gate_bias_); + + layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_); + 
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); + layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); + + layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_); + layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_); + layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_); + layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_); + + layer_norm_lstm.SetProjectionWeights(projection_weights_); + + // Verify the final output. + const std::vector<std::vector<float>> layer_norm_lstm_golden_output = { + { + // Batch0: 3 (input_sequence_size) * 3 (n_output) + 0.0244077, 0.128027, -0.00170918, // seq 0 + 0.0137642, 0.140751, 0.0395835, // seq 1 + -0.00459231, 0.155278, 0.0837377, // seq 2 + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output) + -0.00692428, 0.0848741, 0.063445, // seq 0 + -0.00403912, 0.139963, 0.072681, // seq 1 + 0.00752706, 0.161903, 0.0561371, // seq 2 + }}; + + VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output, + &layer_norm_lstm); +} + +TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, + HybridLayerNormLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const float ceil_clip = 0.0; + const float proj_clip = 0.0; + + HybridLayerNormLSTMOpModel layer_norm_lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, ceil_clip, proj_clip, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + 
{n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_layer_norm_weight tensor + {n_cell}, // forget_layer_norm_weight tensor + {n_cell}, // cell_layer_norm_weight tensor + {n_cell}, // output_layer_norm_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_); + layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); + layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); + layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); + + layer_norm_lstm.SetInputGateBias(input_gate_bias_); + layer_norm_lstm.SetCellBias(cell_gate_bias_); + layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); + layer_norm_lstm.SetOutputGateBias(output_gate_bias_); + + layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_); + layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); + layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); + + layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_); + layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_); + layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_); + layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_); + + layer_norm_lstm.SetProjectionWeights(projection_weights_); + + const 
std::vector<std::vector<float>> layer_norm_lstm_golden_output = { + { + // Batch0: 3 (input_sequence_size) * 3 (n_output) + 0.0244576, 0.127847, -0.00181765, // seq 0 + 0.0137518, 0.140892, 0.0402234, // seq 1 + -0.0048839, 0.155096, 0.0840309, // seq 2 + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output) + -0.00728636, 0.0843957, 0.0634786, // seq 0 + -0.00448382, 0.139278, 0.0737372, // seq 1 + 0.00734616, 0.161793, 0.0560238, // seq 2 + }}; + + VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output, + &layer_norm_lstm); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 55bcf3b533..3bce05353d 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -92,8 +92,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { op_context.constant_values->type); } - // TODO(nupurgarg): Our current implementations rely on the inputs being 4D. - TF_LITE_ENSURE_EQ(context, op_context.dims, 4); + // TODO(nupurgarg): Current implementations rely on the inputs being <= 4D. + TF_LITE_ENSURE(context, op_context.dims <= 4); // Exit early if paddings is a non-const tensor. Set output tensor to // dynamic so output size can be determined in Eval. 
@@ -134,21 +134,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { after_padding.push_back(paddings_data[idx * 2 + 1]); } -#define TF_LITE_PAD(type, scalar, pad_value) \ - TF_LITE_ENSURE_EQ(context, before_padding.size(), 4); \ - TF_LITE_ENSURE_EQ(context, after_padding.size(), 4); \ - tflite::PadParams op_params; \ - op_params.left_padding_count = 4; \ - op_params.right_padding_count = 4; \ - for (int i = 0; i < 4; ++i) { \ - op_params.left_padding[i] = before_padding[3 - i]; \ - op_params.right_padding[i] = after_padding[3 - i]; \ - } \ - const scalar pad_value_copy = pad_value; \ - \ - type::Pad(op_params, GetTensorShape(op_context.input), \ - GetTensorData<scalar>(op_context.input), &pad_value_copy, \ - GetTensorShape(op_context.output), \ +#define TF_LITE_PAD(type, scalar, pad_value) \ + TF_LITE_ENSURE(context, before_padding.size() <= 4); \ + TF_LITE_ENSURE(context, after_padding.size() <= 4); \ + tflite::PadParams op_params; \ + op_params.left_padding_count = before_padding.size(); \ + op_params.right_padding_count = after_padding.size(); \ + for (int i = 0; i < op_context.dims; ++i) { \ + op_params.left_padding[i] = before_padding[op_context.dims - 1 - i]; \ + op_params.right_padding[i] = after_padding[op_context.dims - 1 - i]; \ + } \ + const scalar pad_value_copy = pad_value; \ + \ + type::Pad(op_params, GetTensorShape(op_context.input), \ + GetTensorData<scalar>(op_context.input), &pad_value_copy, \ + GetTensorShape(op_context.output), \ GetTensorData<scalar>(op_context.output)) switch (op_context.input->type) { case kTfLiteFloat32: { diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc index f8b9064fbb..f663899713 100644 --- a/tensorflow/contrib/lite/kernels/pad_test.cc +++ b/tensorflow/contrib/lite/kernels/pad_test.cc @@ -193,7 +193,7 @@ TEST(PadOpTest, TooManyDimensions) { PadOpConstModel({TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2}, {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 
7, 8, 8, 9, 9}, {TensorType_FLOAT32}), - "dims != 4"); + "dims <= 4"); } TEST(PadOpTest, UnequalDimensions) { @@ -221,6 +221,15 @@ TEST(PadOpTest, SimpleConstTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } +TEST(PadOpTest, SimpleConst1DTest) { + PadOpConstModel m({TensorType_FLOAT32, {2}}, {1, 2}, {1, 2}, + {TensorType_FLOAT32}); + m.SetInput({2, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 3, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5})); +} + TEST(PadOpTest, SimpleDynamicTest) { PadOpDynamicModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, {TensorType_FLOAT32}); @@ -334,7 +343,7 @@ TEST(PadV2OpTest, TooManyDimensions) { {TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2}, {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, 0.0, {TensorType_FLOAT32}), - "dims != 4"); + "dims <= 4"); } TEST(PadV2OpTest, UnequalDimensions) { diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 7b859dc332..c66959fdf4 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -22,8 +22,10 @@ namespace ops { namespace custom { TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); +TfLiteRegistration* Register_LAYER_NORM_LSTM(); TfLiteRegistration* Register_MFCC(); TfLiteRegistration* Register_DETECTION_POSTPROCESS(); +TfLiteRegistration* Register_RELU_1(); } // namespace custom @@ -247,6 +249,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddCustom("Mfcc", tflite::ops::custom::Register_MFCC()); AddCustom("AudioSpectrogram", tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); + AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM()); + AddCustom("Relu1", tflite::ops::custom::Register_RELU_1()); AddCustom("TFLite_Detection_PostProcess", tflite::ops::custom::Register_DETECTION_POSTPROCESS()); } diff --git a/tensorflow/contrib/lite/kernels/relu1.cc 
b/tensorflow/contrib/lite/kernels/relu1.cc new file mode 100644 index 0000000000..abafee2d57 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/relu1.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace relu1 { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TfLiteTensor* output = GetOutput(context, node, 0); + output->type = input->type; + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +// This is derived from lite/kernels/activations.cc. 
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + const int elements = NumElements(input); + const float* in = input->data.f; + const float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; ++in, ++out) { + *out = std::min(std::max(0.f, *in), 1.f); + } + return kTfLiteOk; +} + +} // namespace relu1 + +TfLiteRegistration* Register_RELU_1() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + relu1::Prepare, relu1::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/relu1_test.cc b/tensorflow/contrib/lite/kernels/relu1_test.cc new file mode 100644 index 0000000000..c1e0149c20 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/relu1_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <gtest/gtest.h> +#include "flatbuffers/flexbuffers.h" // flatbuffers +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_RELU_1(); + +namespace { + +using ::testing::ElementsAreArray; + +class BaseActivationsOpModel : public SingleOpModel { + public: + explicit BaseActivationsOpModel(const TensorData& input) { + input_ = AddInput(input); + output_ = AddOutput({input.type, {}}); + flexbuffers::Builder fbb; + fbb.Map([&]() {}); + fbb.Finish(); + SetCustomOp("RELU_1", fbb.GetBuffer(), Register_RELU_1); + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + int output_; +}; + +class FloatActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + void SetInput(std::initializer_list<float> data) { + PopulateTensor(input_, data); + } + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } +}; + +TEST(FloatActivationsOpTest, Relu1) { + FloatActivationsOpModel m(/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -2.0, 1.1, -0.1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0.0, 0.0, 0.2, 0.0, // + 0.3, 0.0, 1.0, 0.0, // + })); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 57134ccd15..32f02a4f6c 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -1679,6 +1679,7 @@ def 
make_pad_tests(zip_path): # TODO(nupurgarg): Add test for tf.uint8. test_parameters = [ + # 4D: { "dtype": [tf.int32, tf.int64, tf.float32], "input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]], @@ -1686,13 +1687,20 @@ def make_pad_tests(zip_path): [0, 0], [2, 3]]], "constant_paddings": [True, False], }, - # Non-4D use case. + # 2D: { "dtype": [tf.int32, tf.int64, tf.float32], - "input_shape": [[1, 2], [0, 1, 2]], + "input_shape": [[1, 2]], "paddings": [[[0, 1], [2, 3]]], "constant_paddings": [True, False], }, + # 1D: + { + "dtype": [tf.int32], + "input_shape": [[1]], + "paddings": [[[1, 2]]], + "constant_paddings": [False], + }, ] def build_graph(parameters): @@ -1730,6 +1738,7 @@ def make_padv2_tests(zip_path): # TODO(nupurgarg): Add test for tf.uint8. test_parameters = [ + # 4D: { "dtype": [tf.int32, tf.int64, tf.float32], "input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]], @@ -1738,14 +1747,22 @@ def make_padv2_tests(zip_path): "constant_paddings": [True, False], "constant_values": [0, 2], }, - # Non-4D use case. + # 2D: { "dtype": [tf.int32, tf.int64, tf.float32], - "input_shape": [[1, 2], [0, 1, 2]], + "input_shape": [[1, 2]], "paddings": [[[0, 1], [2, 3]]], "constant_paddings": [True, False], "constant_values": [0, 2], }, + # 1D: + { + "dtype": [tf.int32], + "input_shape": [[1]], + "paddings": [[[0, 1]]], + "constant_paddings": [False], + "constant_values": [0, 2], + }, ] def build_graph(parameters): diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 37c7ae0e1c..349aa5a3b4 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -58,12 +58,6 @@ tensorflow::Env* env = tensorflow::Env::Default(); // Key is a substring of the test name and value is a bug number. // TODO(ahentz): make sure we clean this list up frequently. 
std::map<string, string> kBrokenTests = { - // Pad and PadV2 only supports 4D tensors. - {R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])", - "70527055"}, - {R"(^\/padv2.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])", - "70527055"}, - // L2Norm only supports tensors with 4D or fewer. {R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 84f71dc7a7..f14dbc258b 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -247,6 +247,10 @@ struct ParsedTocoFlags { Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false); Arg<int64> dedupe_array_min_size_bytes = Arg<int64>(64); Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true); + // WARNING: Experimental interface, subject to change + Arg<bool> allow_eager_ops = Arg<bool>(false); + // WARNING: Experimental interface, subject to change + Arg<bool> force_eager_ops = Arg<bool>(false); }; } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index 502de88f7c..3114fa93e8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -63,6 +63,25 @@ bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) { return true; } +bool HardcodeInputMinMaxFromOutput(Model* model, Operator* op) { + auto& input = model->GetArray(op->inputs[0]); + if (input.minmax) { + const auto* minmax = input.minmax.get(); + if (minmax) { + return false; + } + } + auto& output = model->GetArray(op->outputs[0]); + if (output.minmax) { + const auto* minmax = model->GetArray(op->outputs[0]).minmax.get(); + if (minmax) { + input.GetOrCreateMinMax() = *minmax; + return true; + } + } + return false; +} + bool 
HardcodeMinMaxForConcatenation(Model* model, Operator* op) { // Do not early return if the output already has min/max: // we may still need to adjust the inputs min/max. @@ -366,6 +385,16 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForL2Normalization(model, op); break; + case OperatorType::kRelu: + // For any normalization other than batch norm, the quantizations ranges + // before and after relu are expected to be known. Having a quantization + // op before relu would reduce the number of bits of precision for the + // activation in half. So we deduce the range before relu from that after + // the relu. This would eliminate the need for two fake quantization nodes + // and would not reduce the bits of precision available for activation. + changed = HardcodeInputMinMaxFromOutput(model, op); + break; + case OperatorType::kConcatenation: changed = HardcodeMinMaxForConcatenation(model, op); break; diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index cb6da21039..9bc23c4b3c 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -2061,8 +2061,14 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef( } Model* model = new Model; - const internal::ConverterMapType& converter_map = - internal::GetTensorFlowNodeConverterMap(); + internal::ConverterMapType converter_map; + + // This is used for the TFLite "Full Eager Mode" conversion. All the ops are + // imported as `TensorFlowUnsupportedOperator`, and later all these ops are + // converted to TFLite Eager ops. 
+ if (!tf_import_flags.import_all_ops_as_unsupported) { + converter_map = internal::GetTensorFlowNodeConverterMap(); + } for (auto node : inlined_graph.node()) { StripZeroOutputIndexFromInputs(&node); diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h index 2177872334..7db23f2d44 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.h +++ b/tensorflow/contrib/lite/toco/import_tensorflow.h @@ -27,6 +27,11 @@ struct TensorFlowImportFlags { // If true, control dependencies will be dropped immediately // during the import of the TensorFlow GraphDef. bool drop_control_dependency = false; + + // Do not recognize any op and import all ops as + // `TensorFlowUnsupportedOperator`. This is used to populated with the + // `force_eager_ops` flag. + bool import_all_ops_as_unsupported = false; }; std::unique_ptr<Model> ImportTensorFlowGraphDef( diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index c79469f59b..fee10b1dff 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -49,12 +49,21 @@ namespace { details::OperatorKey GetOperatorKey( const ::toco::Operator& op, - const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, + bool allow_eager_ops) { string custom_code; if (op.type == OperatorType::kUnsupported) { const TensorFlowUnsupportedOperator& unsupported_op = static_cast<const TensorFlowUnsupportedOperator&>(op); - custom_code = unsupported_op.tensorflow_op; + + // TODO(b/113715895): When `allow_eager_ops` is on, for now there's no way + // to populate a regular custom op. We need to find a way to fix this. 
+ if (allow_eager_ops) { + custom_code = string(::tflite::kEagerCustomCodePrefix) + + unsupported_op.tensorflow_op; + } else { + custom_code = unsupported_op.tensorflow_op; + } } int version = 1; if (ops_by_type.count(op.type) != 0) { @@ -91,11 +100,12 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { void LoadOperatorsMap( const Model& model, OperatorsMap* operators_map, - const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, + bool allow_eager_ops) { // First find a list of unique operator types. std::set<OperatorKey> keys; for (const auto& op : model.operators) { - keys.insert(GetOperatorKey(*op, ops_by_type)); + keys.insert(GetOperatorKey(*op, ops_by_type, allow_eager_ops)); } // Now assign indices to them and fill in the map. int index = 0; @@ -189,7 +199,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes( const Model& model, const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, const details::OperatorsMap& operators_map, FlatBufferBuilder* builder, - std::set<string>* error_summary) { + std::set<string>* error_summary, const ExportParams& params) { // Map from operator name to TF Lite enum value, for all builtins. 
std::map<string, BuiltinOperator> builtin_ops; for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) { @@ -205,7 +215,8 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes( std::map<int, Offset<OperatorCode>> ordered_opcodes; for (const auto& op : model.operators) { - const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type); + const details::OperatorKey operator_key = + GetOperatorKey(*op, ops_by_type, params.allow_eager_ops); int op_index = operators_map.at(operator_key); int op_version = operator_key.version; @@ -252,7 +263,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators( const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, const details::OperatorsMap& operators_map, const details::TensorsMap& tensors_map, FlatBufferBuilder* builder, - std::set<int32_t>* variable_tensor_indices) { + std::set<int32_t>* variable_tensor_indices, const ExportParams& params) { variable_tensor_indices->clear(); // The operators are in execution order, so we just follow tf.mini order. 
@@ -269,7 +280,8 @@ Offset<Vector<Offset<Operator>>> ExportOperators( outputs.push_back(tensors_map.at(output)); } - int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type)); + int op_index = operators_map.at( + GetOperatorKey(*op, ops_by_type, params.allow_eager_ops)); auto tflite_op_it = ops_by_type.find(op->type); BaseOperator* tflite_op = tflite_op_it == ops_by_type.end() @@ -320,16 +332,15 @@ Offset<Vector<Offset<Buffer>>> ExportBuffers( return builder->CreateVector(buffer_vector); } -void Export(const Model& model, bool allow_custom_ops, bool quantize_weights, - string* output_file_contents) { - const auto ops_by_type = BuildOperatorByTypeMap(); - Export(model, allow_custom_ops, quantize_weights, output_file_contents, - ops_by_type); +void Export(const Model& model, string* output_file_contents, + const ExportParams& params) { + const auto ops_by_type = BuildOperatorByTypeMap(params.allow_eager_ops); + Export(model, output_file_contents, params, ops_by_type); } void Export( - const Model& model, bool allow_custom_ops, bool quantize_weights, - string* output_file_contents, + const Model& model, string* output_file_contents, + const ExportParams& params, const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); @@ -337,7 +348,8 @@ void Export( details::LoadTensorsMap(model, &tensors_map); details::OperatorsMap operators_map; - details::LoadOperatorsMap(model, &operators_map, ops_by_type); + details::LoadOperatorsMap(model, &operators_map, ops_by_type, + params.allow_eager_ops); std::vector<const Array*> buffers_to_write; Array empty_array; @@ -345,7 +357,7 @@ void Export( std::set<string> error_summary; auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, - &builder, &error_summary); + &builder, &error_summary, params); for (const auto& op : model.operators) { if (op->type == OperatorType::kFakeQuant) { @@ -355,7 +367,7 @@ void Export( "for --std_values 
and --mean_values."; } } - if (!allow_custom_ops && !error_summary.empty()) { + if (!params.allow_custom_ops && !error_summary.empty()) { // Remove ExpandDims and ReorderAxes from unimplemented list unless they // compose the list. Both ops are removed during graph transformations. // However, if an op is unimplemented earlier in the model, the graph @@ -383,7 +395,7 @@ void Export( std::set<int32_t> variable_tensor_indices; auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map, - &builder, &variable_tensor_indices); + &builder, &variable_tensor_indices, params); auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write, variable_tensor_indices); @@ -402,7 +414,7 @@ void Export( builder.CreateVector(subgraphs), description, buffers); ::tflite::FinishModelBuffer(builder, new_model_location); - if (quantize_weights) { + if (params.quantize_weights) { // Call the quantize_weights tool. LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. " "dump_graphviz will only output the model before this " diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 915d5dd3d6..b070a38768 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -23,22 +23,54 @@ namespace toco { namespace tflite { +// The parameters for exporting a TFLite model. +struct ExportParams { + bool allow_custom_ops = false; + bool allow_eager_ops = false; + bool quantize_weights = false; +}; + // Transform the given tf.mini model into a TF Lite flatbuffer and deposit the // result in the given string. -void Export(const Model& model, bool allow_custom_ops, bool quantize_weights, - string* output_file_contents); +void Export(const Model& model, string* output_file_contents, + const ExportParams& params); + +// Export API with custom TFLite operator mapping. 
+void Export( + const Model& model, string* output_file_contents, + const ExportParams& params, + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type); -// This if backward-compatibility. +// This is for backward-compatibility. // TODO(ycling): Remove the deprecated entry functions. -inline void Export(const Model& model, string* output_file_contents) { - Export(model, true, false, output_file_contents); +inline void Export(const Model& model, bool allow_custom_ops, + bool quantize_weights, string* output_file_contents) { + ExportParams params; + params.allow_custom_ops = allow_custom_ops; + params.quantize_weights = quantize_weights; + Export(model, output_file_contents, params); } -// Export API with custom TFLite operator mapping. -void Export( +// This is for backward-compatibility. +// TODO(ycling): Remove the deprecated entry functions. +inline void Export( const Model& model, bool allow_custom_ops, bool quantize_weights, string* output_file_contents, - const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type); + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { + ExportParams params; + params.allow_custom_ops = allow_custom_ops; + params.quantize_weights = quantize_weights; + Export(model, output_file_contents, params, ops_by_type); +} + +// This is for backward-compatibility. +// TODO(ycling): Remove the deprecated entry functions. 
+inline void Export(const Model& model, string* output_file_contents) { + ExportParams params; + params.allow_custom_ops = true; + Export(model, output_file_contents, params); + Export(model, true, false, output_file_contents); +} namespace details { @@ -88,7 +120,8 @@ using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>; void LoadTensorsMap(const Model& model, TensorsMap* tensors_map); void LoadOperatorsMap( const Model& model, OperatorsMap* operators_map, - const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type); + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, + bool allow_eager_ops); } // namespace details } // namespace tflite diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index 4994ea30de..8d4d197c46 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -105,7 +105,8 @@ TEST_F(ExportTest, LoadOperatorsMap) { details::OperatorsMap operators; const auto ops_by_type = BuildOperatorByTypeMap(); - details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + // TODO(ycling): Add a test for allow_eager_ops. 
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]); EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]); EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]); @@ -253,7 +254,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) { details::OperatorsMap operators; const auto ops_by_type = BuildFakeOperatorByTypeMap(); - details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(1, operators.size()); EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); @@ -264,7 +265,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) { details::OperatorsMap operators; const auto ops_by_type = BuildFakeOperatorByTypeMap(); - details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(1, operators.size()); EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2))); @@ -276,7 +277,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) { details::OperatorsMap operators; const auto ops_by_type = BuildFakeOperatorByTypeMap(); - details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(2, operators.size()); EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index a314c8d53a..eb0f7c443a 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1149,7 +1149,9 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions, class TensorFlowUnsupported : public BaseOperator { public: - using 
BaseOperator::BaseOperator; + TensorFlowUnsupported(const string& name, OperatorType type, + bool allow_eager_ops) + : BaseOperator(name, type), allow_eager_ops_(allow_eager_ops) {} Options Serialize(const Operator& op, flatbuffers::FlatBufferBuilder* builder) const override { @@ -1165,6 +1167,9 @@ class TensorFlowUnsupported : public BaseOperator { std::unique_ptr<Operator> Deserialize( const BuiltinOptions* builtin_options, const CustomOptions* custom_options) const override { + // Deserializing Eager ops doesn't work now. + // TODO(ycling): Revisit and decide if we should fix the flow for importing + // TFLite models with Eager ops. auto op = absl::make_unique<TensorFlowUnsupportedOperator>(); if (custom_options) { auto flexbuffer_map = @@ -1185,6 +1190,16 @@ class TensorFlowUnsupported : public BaseOperator { return std::unique_ptr<flexbuffers::Builder>(); } + if (allow_eager_ops_) { + fbb->Vector([&]() { + fbb->String(node_def.op()); + fbb->String(op.tensorflow_node_def); + }); + fbb->Finish(); + LOG(INFO) << "Writing eager op: " << node_def.op(); + return std::unique_ptr<flexbuffers::Builder>(fbb.release()); + } + bool has_valid_attr = false; size_t map_start = fbb->StartMap(); for (const auto& pair : node_def.attr()) { @@ -1285,11 +1300,15 @@ class TensorFlowUnsupported : public BaseOperator { // custom ops. return 1; } + + private: + const bool allow_eager_ops_; }; namespace { // Build a vector containing all the known operators. -std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { +std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList( + bool allow_eager_ops = false) { std::vector<std::unique_ptr<BaseOperator>> ops; using tensorflow::MakeUnique; // Builtin Operators. 
@@ -1400,8 +1419,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); ops.push_back(MakeUnique<CTCBeamSearchDecoder>( "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder)); - ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED", - OperatorType::kUnsupported)); + ops.push_back(MakeUnique<TensorFlowUnsupported>( + "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_eager_ops)); // There operators are supported by Toco, but not by TF Lite, and has no // attributes. @@ -1474,10 +1493,12 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { } } // namespace -std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() { +std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap( + bool allow_eager_ops) { std::map<OperatorType, std::unique_ptr<BaseOperator>> result; - std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList(); + std::vector<std::unique_ptr<BaseOperator>> ops = + BuildOperatorList(allow_eager_ops); for (auto& op : ops) { result[op->type()] = std::move(op); } @@ -1485,10 +1506,12 @@ std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() { return result; } -std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() { +std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap( + bool allow_eager_ops) { std::map<string, std::unique_ptr<BaseOperator>> result; - std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList(); + std::vector<std::unique_ptr<BaseOperator>> ops = + BuildOperatorList(allow_eager_ops); for (auto& op : ops) { result[op->name()] = std::move(op); } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h index d9ea23edf2..702fb28ea6 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/contrib/lite/toco/tflite/operator.h 
@@ -26,11 +26,15 @@ namespace tflite { class BaseOperator; // Return a map contained all know TF Lite Operators, keyed by their names. -std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(); +// TODO(ycling): The pattern to propagate parameters (e.g. allow_eager_ops) +// is ugly here. Consider refactoring. +std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap( + bool allow_eager_ops = false); // Return a map contained all know TF Lite Operators, keyed by the type of // their tf.mini counterparts. -std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(); +std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap( + bool allow_eager_ops = false); // These are the flatbuffer types for custom and builtin options. using CustomOptions = flatbuffers::Vector<uint8_t>; diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index f83a290195..b6aebc0470 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -165,7 +165,13 @@ bool ParseTocoFlagsFromCommandLineFlags( parsed_flags.post_training_quantize.default_value(), "Boolean indicating whether to quantize the weights of the " "converted float model. 
Model size will be reduced and there will " - "be latency improvements (at the cost of accuracy).")}; + "be latency improvements (at the cost of accuracy)."), + // WARNING: Experimental interface, subject to change + Flag("allow_eager_ops", parsed_flags.allow_eager_ops.bind(), + parsed_flags.allow_eager_ops.default_value(), ""), + // WARNING: Experimental interface, subject to change + Flag("force_eager_ops", parsed_flags.force_eager_ops.bind(), + parsed_flags.force_eager_ops.default_value(), "")}; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); if (asked_for_help) { @@ -260,6 +266,16 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone); READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone); READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone); + READ_TOCO_FLAG(allow_eager_ops, FlagRequirement::kNone); + READ_TOCO_FLAG(force_eager_ops, FlagRequirement::kNone); + + if (parsed_toco_flags.force_eager_ops.value() && + !parsed_toco_flags.allow_eager_ops.value()) { + // TODO(ycling): Consider to enforce `allow_eager_ops` when + // `force_eager_ops` is true. + LOG(WARNING) << "--force_eager_ops should always be used with " + "--allow_eager_ops."; + } // Deprecated flag handling. if (parsed_toco_flags.input_type.specified()) { diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index c1dd621429..53d60fed05 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -37,7 +37,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 27. +// Next ID to use: 29. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -189,4 +189,17 @@ message TocoFlags { // model. 
Model size will be reduced and there will be latency improvements // (at the cost of accuracy). optional bool post_training_quantize = 26 [default = false]; + + // When enabled, unsupported ops will be converted to TFLite Eager ops. + // TODO(ycling): Consider to rename the following 2 flags and don't call it + // "Eager". + // `allow_eager_ops` should always be used with `allow_custom_ops`. + // WARNING: Experimental interface, subject to change + optional bool allow_eager_ops = 27 [default = false]; + + // When enabled, all TensorFlow ops will be converted to TFLite Eager + // ops directly. This will force `allow_eager_ops` to true. + // `force_eager_ops` should always be used with `allow_eager_ops`. + // WARNING: Experimental interface, subject to change + optional bool force_eager_ops = 28 [default = false]; } diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 7db7acb44d..a7c17156b1 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -197,6 +197,10 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags, toco_flags.has_drop_control_dependency() ? toco_flags.drop_control_dependency() : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF); + + tf_import_flags.import_all_ops_as_unsupported = + toco_flags.force_eager_ops(); + model = ImportTensorFlowGraphDef(model_flags, tf_import_flags, input_file_contents); break; @@ -397,11 +401,21 @@ void Export(const TocoFlags& toco_flags, const Model& model, case TENSORFLOW_GRAPHDEF: ExportTensorFlowGraphDef(model, output_file_contents); break; - case TFLITE: - toco::tflite::Export(model, allow_custom_ops, - toco_flags.post_training_quantize(), - output_file_contents); - break; + case TFLITE: { + toco::tflite::ExportParams params; + + // Always allow custom ops when eager ops are allowed. 
+ if (toco_flags.force_eager_ops() || toco_flags.allow_eager_ops()) { + params.allow_eager_ops = true; + params.allow_custom_ops = true; + } else if (allow_custom_ops) { + params.allow_custom_ops = true; + } + + params.quantize_weights = toco_flags.post_training_quantize(); + + toco::tflite::Export(model, output_file_contents, params); + } break; case GRAPHVIZ_DOT: DumpGraphviz(model, output_file_contents); break; diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc index e0ed7c7946..e5bb3c990a 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc @@ -42,10 +42,9 @@ typedef struct { bool eval_hybrid; } TensorInfo; -// The minimum number of elements a weights array must have to be quantized -// by this transformation. -// TODO(suharshs): Make this configurable. -const int kWeightsMinSize = 1024; +// The default minimum number of elements a weights array must have to be +// quantized by this transformation. +const int kWeightsMinNumElementsDefault = 1024; // Nudge min and max so that floating point 0 falls exactly on a quantized // value, returning the nudges scale and zero_point. @@ -158,42 +157,45 @@ bool IsHybridEvaluationOp(const OperatorT* op, const BuiltinOperator& op_code) { // Returns a vector of TensorInfos for each input tensor of op that should be // quantized. 
-std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model, - const OperatorT* op) { +std::vector<TensorInfo> GetQuantizableTensorsFromOperator( + const ModelT* model, const OperatorT* op, uint64_t weights_min_num_elements, + bool use_hybrid_evaluation) { SubGraphT* subgraph = model->subgraphs.at(0).get(); const BuiltinOperator op_code = model->operator_codes[op->opcode_index]->builtin_code; std::vector<TensorInfo> tensor_infos; - bool eval_hybrid = IsHybridEvaluationOp(op, op_code); + bool eval_hybrid = use_hybrid_evaluation && IsHybridEvaluationOp(op, op_code); bool skipped_tensor = false; std::vector<int32_t> op_input_indices = GetWeightInputIndices(op_code); for (const int32_t op_input_idx : op_input_indices) { int32_t tensor_idx = op->inputs[op_input_idx]; + TensorT* tensor = subgraph->tensors[tensor_idx].get(); // TODO(suharshs): Support shared weights, i.e. If two tensors share the // same weight array, things may break. (i.e. SSD object detection) - if (CountTensorConsumers(model, subgraph, tensor_idx) != 1) { - LOG(INFO) << "Skipping quantization of tensor that is shared between " - "multiple multiple operations."; + if (!eval_hybrid && + CountTensorConsumers(model, subgraph, tensor_idx) != 1) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " that is shared between multiple multiple operations."; skipped_tensor = true; continue; } - TensorT* tensor = subgraph->tensors[tensor_idx].get(); - if (tensor->type != TensorType_FLOAT32) { - LOG(INFO) << "Skipping quantization of tensor that is not type float."; + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " that is not type float."; skipped_tensor = true; continue; } const uint64_t num_elements = NumElements(tensor); - if (num_elements < kWeightsMinSize) { - LOG(INFO) << "Skipping quantization of tensor because it has fewer than " - << kWeightsMinSize << " elements (" << num_elements << ")."; + if (num_elements < weights_min_num_elements) { + 
LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " because it has fewer than " << weights_min_num_elements + << " elements (" << num_elements << ")."; skipped_tensor = true; continue; } @@ -331,11 +333,10 @@ void MakeTensor(const string& name, const std::vector<int32_t>& shape, tensor->reset(tensor_raw); } -} // namespace - -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, - const Model* input_model, - bool use_hybrid_evaluation) { +TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + bool use_hybrid_evaluation, + uint64_t weights_min_num_elements) { std::unique_ptr<ModelT> model; model.reset(input_model->UnPack()); @@ -352,11 +353,11 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, for (int i = 0; i < subgraph->operators.size(); ++i) { OperatorT* op = subgraph->operators[i].get(); - std::vector<TensorInfo> tensor_infos = - GetQuantizableTensorsFromOperator(model.get(), op); + std::vector<TensorInfo> tensor_infos = GetQuantizableTensorsFromOperator( + model.get(), op, weights_min_num_elements, use_hybrid_evaluation); for (const TensorInfo& tensor_info : tensor_infos) { - if (use_hybrid_evaluation && tensor_info.eval_hybrid) { + if (tensor_info.eval_hybrid) { // Quantize the tensor. TF_LITE_ENSURE_STATUS( SymmetricQuantizeTensor(model.get(), tensor_info.tensor)); @@ -399,9 +400,32 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, return kTfLiteOk; } +} // namespace + +namespace internal { +TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + bool use_hybrid_evaluation) { + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. 
+ return QuantizeWeightsInternal(builder, input_model, use_hybrid_evaluation, + kWeightsMinNumElementsDefault); +} +} // namespace internal + +TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements) { + return QuantizeWeightsInternal(builder, input_model, true, + weights_min_num_elements); +} + TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model) { - return QuantizeWeights(builder, input_model, true); + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. + return QuantizeWeightsInternal(builder, input_model, true, + kWeightsMinNumElementsDefault); } } // namespace optimize diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h index 3743c0ce53..706f10b87b 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h @@ -25,6 +25,8 @@ namespace tflite { namespace optimize { // Quantizes input_model and populates the provided builder with the new model. +// By default only weights tensors weight more than 1024 elements will be +// quantized. // // A tflite::Model can be obtained from the builder with: // const uint8_t* buffer = builder->GetBufferPointer(); @@ -32,11 +34,22 @@ namespace optimize { TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model); -// Same as above, but if use_hybrid_evaluation is false, will disable using -// hybrid eval for operations that support it. +// Same as above, but only weights with greater than or equal +// weights_min_num_elements elements will be quantized. 
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements); + +namespace internal { +// If use_hybrid_evaluation is false, will disable using hybrid eval for +// operations that support it. +// +// We use this internal QuantizeWeights call to test models with hybrid +// evaluation disabled. TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, bool use_hybrid_evaluation); +} // namespace internal } // namespace optimize } // namespace tflite diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc index efaf9929e9..387b3471c2 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc @@ -76,7 +76,8 @@ class QuantizeWeightsTest : public ::testing::Test { void CheckWeights(const Model* input_model_packed, const Model* output_model_packed, - bool use_hybrid_evaluation) { + bool use_hybrid_evaluation, + uint64_t weights_min_num_elements = 1024) { std::unique_ptr<ModelT> input_model; input_model.reset(input_model_packed->UnPack()); @@ -113,8 +114,9 @@ class QuantizeWeightsTest : public ::testing::Test { int tensor_size = GetElementsNum(tensor); // If the tensor_size is less than 1024 we expect the tensor to remain // unquantized. - if (tensor_size < 1024) { - ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name; + if (tensor_size < weights_min_num_elements) { + ASSERT_TRUE(tensor->type == TensorType_FLOAT32) + << tensor->name << " of type " << tensor->type; const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx); // The weight tensor should not come from a dequantize op. ASSERT_TRUE(preceding_op == nullptr); @@ -183,7 +185,7 @@ TEST_F(QuantizeWeightsTest, SimpleTestWithoutHybrid) { flatbuffers::FlatBufferBuilder builder; // Disable hybrid evaluation. 
- EXPECT_EQ(QuantizeWeights(&builder, input_model, false), kTfLiteOk); + EXPECT_EQ(internal::QuantizeWeights(&builder, input_model, false), kTfLiteOk); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -191,6 +193,26 @@ TEST_F(QuantizeWeightsTest, SimpleTestWithoutHybrid) { CheckWeights(input_model, output_model, false); } +TEST_F(QuantizeWeightsTest, SimpleTestWithWeightsMinNumElements) { + string model_path = + "third_party/tensorflow/contrib/lite/tools/optimize/testdata/" + "mobilenet_v1_0.25_128.tflite"; + std::unique_ptr<FlatBufferModel> input_fb = + FlatBufferModel::BuildFromFile(model_path.data()); + const Model* input_model = input_fb->GetModel(); + + flatbuffers::FlatBufferBuilder builder; + // Make weights_min_size sufficiently large such that no quantization should + // happen, i.e. the original model is the same size as the old one. + const uint64_t kWeightsMinNumElements = 1000000; + EXPECT_EQ(QuantizeWeights(&builder, input_model, kWeightsMinNumElements), + kTfLiteOk); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + CheckWeights(input_model, output_model, true, kWeightsMinNumElements); +} + // TODO(suharshs): Add tests that run the resulting model. 
} // namespace diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index bbafd59aae..6c203e5519 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -128,12 +128,14 @@ class ElasticAverageCustomGetter(object): = list(global_center_variable)[i] return local_var else: - return getter( - name, - trainable=trainable, - collections=collections, - *args, - **kwargs) + kwargs['trainable'] = trainable + kwargs['collections'] = collections + if ops.GraphKeys.LOCAL_VARIABLES in collections: + with ops.device(self._worker_device): + return getter(name, *args, **kwargs) + else: + return getter(name, *args, **kwargs) + class ElasticAverageOptimizer(optimizer.Optimizer): diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py index 72117c1e81..f026f437dc 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py @@ -25,9 +25,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import adam @@ -46,7 +48,12 @@ class LazyAdamOptimizer(adam.AdamOptimizer): may lead to different empirical results. 
""" - def _apply_sparse(self, grad, var): + def _apply_sparse_shared(self, + grad, + var, + indices, + scatter_update, + scatter_sub): beta1_power, beta2_power = self._get_beta_accumulators() beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) @@ -58,23 +65,51 @@ class LazyAdamOptimizer(adam.AdamOptimizer): # \\(m := beta1 * m + (1 - beta1) * g_t\\) m = self.get_slot(var, "m") - m_t = state_ops.scatter_update(m, grad.indices, - beta1_t * array_ops.gather(m, grad.indices) + - (1 - beta1_t) * grad.values, - use_locking=self._use_locking) + m_t = scatter_update(m, indices, + beta1_t * array_ops.gather(m, indices) + + (1 - beta1_t) * grad) # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) v = self.get_slot(var, "v") - v_t = state_ops.scatter_update(v, grad.indices, - beta2_t * array_ops.gather(v, grad.indices) + - (1 - beta2_t) * math_ops.square(grad.values), - use_locking=self._use_locking) + v_t = scatter_update(v, indices, + beta2_t * array_ops.gather(v, indices) + + (1 - beta2_t) * math_ops.square(grad)) # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) - m_t_slice = array_ops.gather(m_t, grad.indices) - v_t_slice = array_ops.gather(v_t, grad.indices) + m_t_slice = array_ops.gather(m_t, indices) + v_t_slice = array_ops.gather(v_t, indices) denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t - var_update = state_ops.scatter_sub(var, grad.indices, - lr * m_t_slice / denominator_slice, - use_locking=self._use_locking) + var_update = scatter_sub(var, indices, + lr * m_t_slice / denominator_slice) return control_flow_ops.group(var_update, m_t, v_t) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared( + grad.values, var, grad.indices, + self._scatter_update, + self._scatter_sub) + + def _resource_apply_sparse(self, grad, var, indices): + return self._apply_sparse_shared( + grad, var, indices, + self._resource_scatter_update, + self._resource_scatter_sub) 
+ + # Utility functions for updating resource or non-resource variables. + def _scatter_update(self, x, i, v): + return state_ops.scatter_update( + x, i, v, use_locking=self._use_locking) + + def _scatter_sub(self, x, i, v): + return state_ops.scatter_sub( + x, i, v, use_locking=self._use_locking) + + def _resource_scatter_update(self, x, i, v): + update_op = resource_variable_ops.resource_scatter_update(x.handle, i, v) + with ops.control_dependencies([update_op]): + return x.value() + + def _resource_scatter_sub(self, x, i, v): + sub_op = resource_variable_ops.resource_scatter_sub(x.handle, i, v) + with ops.control_dependencies([sub_op]): + return x.value() diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py index dc4c462ce4..d3e9e89502 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -51,7 +52,7 @@ def adam_update_numpy(param, class AdamOptimizerTest(test.TestCase): - def testSparse(self): + def doTestSparse(self, use_resource=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): # Initialize variables for numpy implementation. 
@@ -61,8 +62,12 @@ class AdamOptimizerTest(test.TestCase): var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) grads0_np_indices = np.array([0, 1], dtype=np.int32) grads0 = ops.IndexedSlices( constant_op.constant(grads0_np), @@ -94,6 +99,12 @@ class AdamOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(var0_np, var0.eval()) self.assertAllCloseAccordingToType(var1_np, var1.eval()) + def testSparse(self): + self.doTestSparse(use_resource=False) + + def testResourceSparse(self): + self.doTestSparse(use_resource=True) + def testSparseDevicePlacement(self): for index_dtype in [dtypes.int32, dtypes.int64]: with self.test_session(force_gpu=test.is_gpu_available()): diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py index b6b10e500b..746df77ba2 100644 --- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py @@ -89,7 +89,13 @@ class ModelAverageCustomGetter(object): self._local_2_global[local_var] = global_variable return local_var else: - return getter(name, trainable, collections, *args, **kwargs) + kwargs['trainable'] = trainable + kwargs['collections'] = collections + if ops.GraphKeys.LOCAL_VARIABLES in collections: + with ops.device(self._worker_device): + return getter(name, *args, **kwargs) + else: + return getter(name, *args, **kwargs) class ModelAverageOptimizer(optimizer.Optimizer): diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py 
b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py index 3acd940268..b1fc50a21f 100644 --- a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py @@ -80,28 +80,28 @@ def _get_workers(num_workers, steps, workers): var_0 = variable_scope.get_variable(initializer=0.0, name="v0") var_1 = variable_scope.get_variable(initializer=1.0, name="v1") - with ops.device("/job:worker/task:" + str(worker_id)): - if worker_id == 0: - grads_0 = constant_op.constant(-1.0) - grads_1 = constant_op.constant(-1.0) - else: - grads_0 = constant_op.constant(-2.0) - grads_1 = constant_op.constant(-2.0) - sgd_opt = gradient_descent.GradientDescentOptimizer(1.0) - opt = model_average_optimizer.ModelAverageOptimizer( - opt=sgd_opt, - num_worker=num_workers, - ma_custom_getter=ma_coustom, - is_chief=is_chief, - interval_steps=steps) - train_op = [ - opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]], - global_step) - ] - easgd_hook = opt.make_session_run_hook() + with ops.device("/job:worker/task:" + str(worker_id)): + if worker_id == 0: + grads_0 = constant_op.constant(-1.0) + grads_1 = constant_op.constant(-1.0) + else: + grads_0 = constant_op.constant(-2.0) + grads_1 = constant_op.constant(-2.0) + sgd_opt = gradient_descent.GradientDescentOptimizer(1.0) + opt = model_average_optimizer.ModelAverageOptimizer( + opt=sgd_opt, + num_worker=num_workers, + ma_custom_getter=ma_coustom, + is_chief=is_chief, + interval_steps=steps) + train_op = [ + opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]], + global_step) + ] + ma_hook = opt.make_session_run_hook() # Creates MonitoredSession sess = training.MonitoredTrainingSession( - workers[worker_id].target, hooks=[easgd_hook]) + workers[worker_id].target, hooks=[ma_hook]) sessions.append(sess) graphs.append(graph) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py 
b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 15ce9d1ce7..be0306cb07 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -48,7 +48,7 @@ Linear = core_rnn_cell._Linear # pylint: disable=invalid-name class RNNCellTest(test.TestCase): def testLinear(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(1.0)): x = array_ops.zeros([1, 2]) @@ -69,7 +69,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(len(variables_lib.trainable_variables()), 2) def testBasicRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -89,7 +89,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(res[0].shape, (1, 2)) def testBasicRNNCellNotTrainable(self): - with self.test_session() as sess: + with self.cached_session() as sess: def not_trainable_getter(getter, *args, **kwargs): kwargs["trainable"] = False @@ -116,7 +116,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(res[0].shape, (1, 2)) def testIndRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -137,7 +137,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(res[0].shape, (1, 2)) def testGRUCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -165,7 +165,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.156736, 0.156736]]) def testIndyGRUCell(self): - with self.test_session() as sess: + with 
self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -193,7 +193,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.155127, 0.157328]]) def testSRUCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -208,7 +208,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.509682, 0.509682]]) def testSRUCellWithDiffSize(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -288,7 +288,7 @@ class RNNCellTest(test.TestCase): def testBasicLSTMCellDimension0Error(self): """Tests that dimension 0 in both(x and m) shape must be equal.""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): num_units = 2 @@ -309,7 +309,7 @@ class RNNCellTest(test.TestCase): def testBasicLSTMCellStateSizeError(self): """Tests that state_size must be num_units * 2.""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): num_units = 2 @@ -329,7 +329,7 @@ class RNNCellTest(test.TestCase): }) def testBasicLSTMCellStateTupleType(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -360,7 +360,7 @@ class RNNCellTest(test.TestCase): self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) def testBasicLSTMCellWithStateTuple(self): - with self.test_session() as sess: + with self.cached_session() as 
sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -459,7 +459,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(len(res), 2) def testLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 8 num_proj = 6 state_size = num_units + num_proj @@ -494,7 +494,7 @@ class RNNCellTest(test.TestCase): float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) def testLSTMCellVariables(self): - with self.test_session(): + with self.cached_session(): num_units = 8 num_proj = 6 state_size = num_units + num_proj @@ -517,7 +517,7 @@ class RNNCellTest(test.TestCase): "root/lstm_cell/projection/kernel") def testLSTMCellLayerNorm(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 num_proj = 3 batch_size = 1 @@ -562,22 +562,21 @@ class RNNCellTest(test.TestCase): rnn_cell_impl.DropoutWrapper, rnn_cell_impl.ResidualWrapper, lambda cell: rnn_cell_impl.MultiRNNCell([cell])]: - with self.test_session(): - cell = rnn_cell_impl.BasicRNNCell(1) - wrapper = wrapper_type(cell) - wrapper(array_ops.ones([1, 1]), - state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) - self.evaluate([v.initializer for v in cell.variables]) - checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) - prefix = os.path.join(self.get_temp_dir(), "ckpt") - self.evaluate(cell._bias.assign([40.])) - save_path = checkpoint.save(prefix) - self.evaluate(cell._bias.assign([0.])) - checkpoint.restore(save_path).assert_consumed().run_restore_ops() - self.assertAllEqual([40.], self.evaluate(cell._bias)) + cell = rnn_cell_impl.BasicRNNCell(1) + wrapper = wrapper_type(cell) + wrapper(array_ops.ones([1, 1]), + state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) + self.evaluate([v.initializer for v in cell.variables]) + checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) + prefix = os.path.join(self.get_temp_dir(), "ckpt") 
+ self.evaluate(cell._bias.assign([40.])) + save_path = checkpoint.save(prefix) + self.evaluate(cell._bias.assign([0.])) + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([40.], self.evaluate(cell._bias)) def testOutputProjectionWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -594,7 +593,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.231907, 0.231907]]) def testInputProjectionWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -612,7 +611,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) def testResidualWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -638,7 +637,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[2], res[3]) def testResidualWrapperWithSlice(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 5]) @@ -716,7 +715,7 @@ class RNNCellTest(test.TestCase): self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name]) def testEmbeddingWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 1], dtype=dtypes.int32) @@ -735,7 +734,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.17139, 0.17139]]) def 
testEmbeddingWrapperWithDynamicRnn(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope("root"): inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64) input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64) @@ -753,7 +752,7 @@ class RNNCellTest(test.TestCase): sess.run(outputs) def testMultiRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -770,7 +769,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]]) def testMultiRNNCellWithStateTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -809,7 +808,7 @@ class DropoutWrapperTest(test.TestCase): time_steps=None, parallel_iterations=None, **kwargs): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): if batch_size is None and time_steps is None: diff --git a/tensorflow/contrib/saved_model/cc/saved_model/BUILD b/tensorflow/contrib/saved_model/cc/saved_model/BUILD index 3c616c555b..ea4d41d43b 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/cc/saved_model/BUILD @@ -30,6 +30,7 @@ cc_library( hdrs = ["signature_def_utils.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", @@ -42,6 +43,7 @@ tf_cc_test( srcs = ["signature_def_utils_test.cc"], deps = [ ":signature_def_utils", + "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/core:framework", 
"//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc index a45908d272..e87e497e5f 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h" +#include "tensorflow/cc/saved_model/signature_constants.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/protobuf.h" @@ -33,6 +35,79 @@ Status FindInProtobufMap(StringPiece description, *value = &it->second; return Status::OK(); } + +// Looks up the TensorInfo for the given key in the given map and verifies that +// its datatype matches the given correct datatype. 
+bool VerifyTensorInfoForKeyInMap(const protobuf::Map<string, TensorInfo>& map, + const string& key, DataType correct_dtype) { + const TensorInfo* tensor_info; + const Status& status = FindInProtobufMap("", map, key, &tensor_info); + if (!status.ok()) { + return false; + } + if (tensor_info->dtype() != correct_dtype) { + return false; + } + return true; +} + +bool IsValidPredictSignature(const SignatureDef& signature_def) { + if (signature_def.method_name() != kPredictMethodName) { + return false; + } + if (signature_def.inputs().empty()) { + return false; + } + if (signature_def.outputs().empty()) { + return false; + } + return true; +} + +bool IsValidRegressionSignature(const SignatureDef& signature_def) { + if (signature_def.method_name() != kRegressMethodName) { + return false; + } + if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kRegressInputs, + DT_STRING)) { + return false; + } + if (!VerifyTensorInfoForKeyInMap(signature_def.outputs(), kRegressOutputs, + DT_FLOAT)) { + return false; + } + return true; +} + +bool IsValidClassificationSignature(const SignatureDef& signature_def) { + if (signature_def.method_name() != kClassifyMethodName) { + return false; + } + if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kClassifyInputs, + DT_STRING)) { + return false; + } + if (signature_def.outputs().empty()) { + return false; + } + for (auto const& output : signature_def.outputs()) { + const string& key = output.first; + const TensorInfo& tensor_info = output.second; + if (key == kClassifyOutputClasses) { + if (tensor_info.dtype() != DT_STRING) { + return false; + } + } else if (key == kClassifyOutputScores) { + if (tensor_info.dtype() != DT_FLOAT) { + return false; + } + } else { + return false; + } + } + return true; +} + } // namespace Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def, @@ -74,4 +149,10 @@ Status FindOutputTensorNameByKey(const SignatureDef& signature_def, return Status::OK(); } +bool IsValidSignature(const 
SignatureDef& signature_def) { + return IsValidClassificationSignature(signature_def) || + IsValidRegressionSignature(signature_def) || + IsValidPredictSignature(signature_def); +} + } // namespace tensorflow diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h index b732cdd41e..bb24faa989 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h @@ -64,6 +64,9 @@ Status FindInputTensorNameByKey(const SignatureDef& signature_def, Status FindOutputTensorNameByKey(const SignatureDef& signature_def, const string& tensor_info_key, string* name); +// Determine whether a SignatureDef can be served by TensorFlow Serving. +bool IsValidSignature(const SignatureDef& signature_def); + } // namespace tensorflow #endif // TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_ diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc index a063e95696..c743112ce0 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h" +#include "tensorflow/cc/saved_model/signature_constants.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -22,7 +23,7 @@ limitations under the License. 
namespace tensorflow { -class SignatureDefUtilsTest : public ::testing::Test { +class FindByKeyTest : public ::testing::Test { protected: MetaGraphDef MakeSampleMetaGraphDef() { MetaGraphDef result; @@ -32,13 +33,23 @@ class SignatureDefUtilsTest : public ::testing::Test { return result; } + void SetInputNameForKey(const string& key, const string& name, + SignatureDef* signature_def) { + (*signature_def->mutable_inputs())[key].set_name(name); + } + + void SetOutputNameForKey(const string& key, const string& name, + SignatureDef* signature_def) { + (*signature_def->mutable_outputs())[key].set_name(name); + } + SignatureDef MakeSampleSignatureDef() { SignatureDef result; result.set_method_name(kMethodName); - (*result.mutable_inputs())[kInput1Key].set_name(kInput1Name); - (*result.mutable_inputs())[kInput2Key].set_name(kInput2Name); - (*result.mutable_outputs())[kOutput1Key].set_name(kOutput1Name); - (*result.mutable_outputs())[kOutput2Key].set_name(kOutput2Name); + SetInputNameForKey(kInput1Key, kInput1Name, &result); + SetInputNameForKey(kInput2Key, kInput2Name, &result); + SetOutputNameForKey(kOutput1Key, kOutput1Name, &result); + SetOutputNameForKey(kOutput2Key, kOutput2Name, &result); return result; } @@ -54,7 +65,7 @@ class SignatureDefUtilsTest : public ::testing::Test { const string kOutput2Name = "output_two"; }; -TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) { +TEST_F(FindByKeyTest, FindSignatureDefByKey) { const MetaGraphDef meta_graph_def = MakeSampleMetaGraphDef(); const SignatureDef* signature_def; // Succeeds for an existing signature. @@ -67,7 +78,7 @@ TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) { .ok()); } -TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) { +TEST_F(FindByKeyTest, FindInputTensorNameByKey) { const SignatureDef signature_def = MakeSampleSignatureDef(); string name; // Succeeds for an existing input. 
@@ -78,7 +89,7 @@ TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) { FindInputTensorNameByKey(signature_def, "nonexistent", &name).ok()); } -TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) { +TEST_F(FindByKeyTest, FindOutputTensorNameByKey) { const SignatureDef signature_def = MakeSampleSignatureDef(); string name; // Succeeds for an existing output. @@ -89,4 +100,100 @@ TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) { FindOutputTensorNameByKey(signature_def, "nonexistent", &name).ok()); } +class IsValidSignatureTest : public ::testing::Test { + protected: + void SetInputDataTypeForKey(const string& key, DataType dtype) { + (*signature_def_.mutable_inputs())[key].set_dtype(dtype); + } + + void SetOutputDataTypeForKey(const string& key, DataType dtype) { + (*signature_def_.mutable_outputs())[key].set_dtype(dtype); + } + + void EraseOutputKey(const string& key) { + (*signature_def_.mutable_outputs()).erase(key); + } + + void ExpectInvalidSignature() { + EXPECT_FALSE(IsValidSignature(signature_def_)); + } + + void ExpectValidSignature() { EXPECT_TRUE(IsValidSignature(signature_def_)); } + + SignatureDef signature_def_; +}; + +TEST_F(IsValidSignatureTest, IsValidPredictSignature) { + signature_def_.set_method_name("not_kPredictMethodName"); + // Incorrect method name + ExpectInvalidSignature(); + + signature_def_.set_method_name(kPredictMethodName); + // No inputs + ExpectInvalidSignature(); + + SetInputDataTypeForKey(kPredictInputs, DT_STRING); + // No outputs + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kPredictOutputs, DT_STRING); + ExpectValidSignature(); +} + +TEST_F(IsValidSignatureTest, IsValidRegressionSignature) { + signature_def_.set_method_name("not_kRegressMethodName"); + // Incorrect method name + ExpectInvalidSignature(); + + signature_def_.set_method_name(kRegressMethodName); + // No inputs + ExpectInvalidSignature(); + + SetInputDataTypeForKey(kRegressInputs, DT_STRING); + // No outputs + ExpectInvalidSignature(); + 
+ SetOutputDataTypeForKey(kRegressOutputs, DT_STRING); + // Incorrect data type + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kRegressOutputs, DT_FLOAT); + ExpectValidSignature(); +} + +TEST_F(IsValidSignatureTest, IsValidClassificationSignature) { + signature_def_.set_method_name("not_kClassifyMethodName"); + // Incorrect method name + ExpectInvalidSignature(); + + signature_def_.set_method_name(kClassifyMethodName); + // No inputs + ExpectInvalidSignature(); + + SetInputDataTypeForKey(kClassifyInputs, DT_STRING); + // No outputs + ExpectInvalidSignature(); + + SetOutputDataTypeForKey("invalidKey", DT_FLOAT); + // Invalid key + ExpectInvalidSignature(); + + EraseOutputKey("invalidKey"); + SetOutputDataTypeForKey(kClassifyOutputClasses, DT_FLOAT); + // Invalid dtype for classes + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kClassifyOutputClasses, DT_STRING); + // Valid without scores + ExpectValidSignature(); + + SetOutputDataTypeForKey(kClassifyOutputScores, DT_STRING); + // Invalid dtype for scores + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kClassifyOutputScores, DT_FLOAT); + // Valid with both classes and scores + ExpectValidSignature(); +} + } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 652f709fe2..00c855daa3 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -462,7 +462,10 @@ py_test( size = "small", srcs = ["python/kernel_tests/scatter_add_ndim_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip_gpu"], + tags = [ + "no_gpu", + "no_pip_gpu", + ], deps = [ ":tensor_forest_ops_py", "//tensorflow/python:framework_test_lib", diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto index 2b13343efa..f88dc51636 100644 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -79,12 +79,15 
@@ message StepInfoResult { // The step duration in picoseconds. optional uint64 duration_ps = 2; // The infeed duration in picoseconds. - // Can turn into a map if we want a variable number of ops. optional uint64 infeed_duration_ps = 3; + // The outfeed duration in picoseconds. + optional uint64 host_outfeed_ps = 8; // The start time of this step in picoseconds. optional uint64 begin_ps = 4; // The waiting time within this step in picoseconds. optional uint64 wait_duration_ps = 5; + // The unit b outfeed duration in picoseconds. + optional uint64 unit_b_outfeed_ps = 9; // The time spent on cross-replica-sum in picoseconds. optional uint64 crs_duration_ps = 6; // Percentage of unit b time spent on infeed. diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto index bf807af68b..cbf6809257 100644 --- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -18,8 +18,10 @@ message DynamicLearningRate { message LearningRate { oneof learning_rate { float constant = 1; - DynamicLearningRate dynamic = 2; + // DynamicLearningRate dynamic = 2; -- disabled while code is being + // rewritten. 
} + reserved 2; } message AdagradParameters { diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index ff88508d03..dd7f8b678f 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -170,11 +170,41 @@ class TPUDistributionStrategy(object): worker_re = re.compile('/job:([^/]+)') for device in metadata.devices: if 'TPU:0' in device.name: - self.worker_name = worker_re.search(device.name).group(1) + self._worker_name = worker_re.search(device.name).group(1) break + def _make_assignment_for_model(self, cpu_model): + """Makes a `TPUAssignment` for the passed in `cpu_model`.""" + num_cores = self._num_cores + if num_cores > 1 and cpu_model.stateful: + logging.warning( + 'Model replication does not currently support stateful models. ' + 'Degrading to a single core.') + num_cores = 1 + + return TPUAssignment( + worker_name=self._worker_name, num_cores=num_cores) + + +class TPUAssignment(object): + """This is object holding TPU resources assignment for the concrete model. + + `TPUDistributionStrategy` is responsible to create the instance of + `TPUAssignment`, so, it can dynamically adjust the `num_cores` to use based on + model and input batch sizes. + """ + + def __init__(self, worker_name, num_cores): + self._worker_name = worker_name + self._num_cores = num_cores + + @property + def worker_name(self): + return self._worker_name + @property def num_towers(self): + # TODO(xiejw): Support automatically assign num_cores based on inputs. return self._num_cores @@ -495,8 +525,8 @@ class TPUNumpyInfeedManager(TPUInfeedManager): infeed_dict[tensor] = value return infeed_dict - def __init__(self, distribution_strategy): - self._strategy = distribution_strategy + def __init__(self, tpu_assignment): + self._tpu_assignment = tpu_assignment def _split_tensors(self, inputs): """Split input data across shards. 
@@ -509,16 +539,16 @@ class TPUNumpyInfeedManager(TPUInfeedManager): Returns: List of lists containing the input to feed to each TPU shard. """ - if self._strategy.num_towers == 1: + if self._tpu_assignment.num_towers == 1: return [inputs] batch_size = inputs[0].shape[0] - assert batch_size % self._strategy.num_towers == 0, ( - 'batch_size must be divisible by strategy.num_towers (%s vs %s)' % - (batch_size, self._strategy.num_towers)) - shard_size = batch_size // self._strategy.num_towers + assert batch_size % self._tpu_assignment.num_towers == 0, ( + 'batch_size must be divisible by the number of TPU cores in use (%s ' + 'vs %s)' % (batch_size, self._tpu_assignment.num_towers)) + shard_size = batch_size // self._tpu_assignment.num_towers input_list = [] - for index in range(self._strategy.num_towers): + for index in range(self._tpu_assignment.num_towers): shard_inputs = [ x[index * shard_size:(index + 1) * shard_size] for x in inputs ] @@ -533,8 +563,9 @@ class TPUNumpyInfeedManager(TPUInfeedManager): infeed_op = [] shard_infeed_tensors = [] - for shard_id in range(self._strategy.num_towers): - with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name): + for shard_id in range(self._tpu_assignment.num_towers): + with ops.device( + '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): infeed_tensors = [] with ops.device('/device:TPU:%d' % shard_id): for spec in input_specs: @@ -573,30 +604,31 @@ class TPUDatasetInfeedManager(TPUInfeedManager): # TODO(saeta): Verify tpu_model_op is as expected! return {} - def __init__(self, dataset, distribution_strategy, tpu_session): + # pylint: disable=redefined-outer-name + def __init__(self, dataset, tpu_assignment, tpu_session): """Constructs a TPUDatasetInfeedManager. Must be called within a `KerasTPUModel.tpu_session` context! Args: dataset: A `tf.data.Dataset` to infeed. 
- distribution_strategy: The `TPUDistributionStrategy` used to configure the + tpu_assignment: The `TPUAssignment` used to configure the Keras TPU model. tpu_session: The `tf.Session` object used for running the TPU model. """ self._verify_dataset_shape(dataset) self._dataset = dataset - self._strategy = distribution_strategy + self._tpu_assignment = tpu_assignment dummy_x_shape = dataset.output_shapes[0].as_list() - dummy_x_shape[0] *= distribution_strategy.num_towers + dummy_x_shape[0] *= tpu_assignment.num_towers dummy_y_shape = dataset.output_shapes[1].as_list() - dummy_y_shape[0] *= distribution_strategy.num_towers + dummy_y_shape[0] *= tpu_assignment.num_towers self._iterator = dataset.make_initializable_iterator() tpu_session.run(self._iterator.initializer) self._get_next_ops = [] ctrl_deps = [] - for i in range(distribution_strategy.num_towers): + for i in range(tpu_assignment.num_towers): with ops.control_dependencies(ctrl_deps): # Ensure deterministic # TODO(saeta): Ensure correct placement! get_next_op = self._iterator.get_next() @@ -676,10 +708,11 @@ class TPUDatasetInfeedManager(TPUInfeedManager): def build_infeed_from_input_specs(self, input_specs, execution_mode): shard_infeed_tensors = self._get_next_ops - assert len(shard_infeed_tensors) == self._strategy.num_towers + assert len(shard_infeed_tensors) == self._tpu_assignment.num_towers infeed_ops = [] - for shard_id in range(self._strategy.num_towers): - with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name): + for shard_id in range(self._tpu_assignment.num_towers): + with ops.device( + '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): infeed_ops.append( tpu_ops.infeed_enqueue_tuple( shard_infeed_tensors[shard_id], @@ -702,10 +735,10 @@ class TPUFunction(object): instead of being injected as `feed_dict` items or fetches. 
""" - def __init__(self, model, execution_mode, strategy): + def __init__(self, model, execution_mode, tpu_assignment): self.model = model self.execution_mode = execution_mode - self._strategy = strategy + self._tpu_assignment = tpu_assignment self._compilation_cache = {} self._cloned_model = None @@ -757,7 +790,8 @@ class TPUFunction(object): # Clone our CPU model, running within the TPU device context. with TPURewriteContext(tpu_input_map): with variable_scope.variable_scope('tpu_model_%s' % id(self.model)): - with keras_tpu_variables.replicated_scope(self._strategy.num_towers): + with keras_tpu_variables.replicated_scope( + self._tpu_assignment.num_towers): self._cloned_model = models.clone_model(self.model) # Create a copy of the optimizer for this graph. @@ -827,7 +861,7 @@ class TPUFunction(object): # `execute op` replicates `_model_fn` `num_replicas` times, with each shard # running on a different logical core. compile_op, execute_op = tpu.split_compile_and_replicate( - _model_fn, inputs=[[]] * self._strategy.num_towers) + _model_fn, inputs=[[]] * self._tpu_assignment.num_towers) # Generate CPU side operations to enqueue features/labels and dequeue # outputs from the model call. @@ -835,8 +869,9 @@ class TPUFunction(object): input_specs, self.execution_mode) # Build output ops. 
outfeed_op = [] - for shard_id in range(self._strategy.num_towers): - with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name): + for shard_id in range(self._tpu_assignment.num_towers): + with ops.device( + '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): outfeed_op.extend( tpu_ops.outfeed_dequeue_tuple( dtypes=[spec.dtype for spec in self._outfeed_spec], @@ -886,7 +921,7 @@ class TPUFunction(object): for x, mgr in self.model._numpy_to_infeed_manager_list: if inputs[0] is x: return mgr - return TPUNumpyInfeedManager(self.model._strategy) + return TPUNumpyInfeedManager(self.model._tpu_assignment) def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager): """Looks up the corresponding `TPUModelOp` for a given `input_specs`. @@ -958,7 +993,7 @@ class TPUFunction(object): outputs = [[]] * len(self._outfeed_spec) outputs_per_replica = len(self._outfeed_spec) - for i in range(self._strategy.num_towers): + for i in range(self._tpu_assignment.num_towers): output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) * outputs_per_replica] for j in range(outputs_per_replica): @@ -967,7 +1002,7 @@ class TPUFunction(object): return [np.concatenate(group) for group in outputs] else: return outfeed_outputs[:len(outfeed_outputs) // - self._strategy.num_towers] + self._tpu_assignment.num_towers] def __call__(self, inputs): """__call__ executes the function on the computational hardware. 
@@ -1119,11 +1154,11 @@ class KerasTPUModel(models.Model): self.predict_function = None self.test_function = None self.train_function = None - self._strategy = strategy - cluster_resolver = self._strategy._tpu_cluster_resolver + cluster_resolver = strategy._tpu_cluster_resolver self._tpu_name_or_address = cluster_resolver.get_master() self._cpu_model = cpu_model + self._tpu_assignment = strategy._make_assignment_for_model(cpu_model) self._tpu_model = None self._tpu_weights_initialized = False @@ -1146,7 +1181,7 @@ class KerasTPUModel(models.Model): return { 'cpu_model': self._cpu_model, 'tpu_name_or_address': self._tpu_name_or_address, - 'strategy': self._strategy, + 'tpu_assignment': self._tpu_assignment, } def compile(self, @@ -1207,7 +1242,7 @@ class KerasTPUModel(models.Model): '/keras') if callable(x): with self.tpu_session() as sess,\ - ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name): + ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): dataset = x() if steps_per_epoch is None: raise ValueError('When using tf.data as input to a model, you ' @@ -1215,7 +1250,8 @@ class KerasTPUModel(models.Model): if y is not None: raise ValueError('When using tf.data as input to a model, y must be ' 'None') - infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment, + sess) # Use dummy numpy inputs for the rest of Keras' shape checking. We # intercept them when building the model. x = infeed_manager.dummy_x @@ -1236,7 +1272,8 @@ class KerasTPUModel(models.Model): if validation_steps is None: raise ValueError('When using tf.data as validation for a model, you ' 'should specify the validation_steps argument.') - infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment, + sess) # Use dummy numpy inputs for the rest of Keras' shape checking. 
We # intercept them when building the model. val_x = infeed_manager.dummy_x @@ -1313,7 +1350,8 @@ class KerasTPUModel(models.Model): if y is not None: raise ValueError('When using tf.data as input to a model, y must be ' 'None') - infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment, + sess) # Use dummy numpy inputs for the rest of Keras' shape checking. We # intercept them when building the model. x = infeed_manager.dummy_x @@ -1740,20 +1778,24 @@ class KerasTPUModel(models.Model): def _make_train_function(self): if not self.train_function: self.train_function = TPUFunction( - self, model_fn_lib.ModeKeys.TRAIN, strategy=self._strategy) + self, + model_fn_lib.ModeKeys.TRAIN, + tpu_assignment=self._tpu_assignment) return self.train_function def _make_test_function(self): if not self.test_function: self.test_function = TPUFunction( - self, model_fn_lib.ModeKeys.EVAL, strategy=self._strategy) + self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment) return self.test_function def _make_predict_function(self): if not self.predict_function: self.predict_function = TPUFunction( - self, model_fn_lib.ModeKeys.PREDICT, strategy=self._strategy) + self, + model_fn_lib.ModeKeys.PREDICT, + tpu_assignment=self._tpu_assignment) return self.predict_function def _initialize_weights(self, cloned_model): @@ -1825,6 +1867,7 @@ class KerasTPUModel(models.Model): self._session.close() +# pylint: disable=bad-continuation def _validate_shapes(model): """Validate that all layers in `model` have constant shape.""" for layer in model.layers: @@ -1852,10 +1895,13 @@ Layer: %(layer)s Input shape: %(input_shape)s Output shape: %(output_shape)s """ % { - 'layer': layer, - 'input_shape': layer.input_shape, - 'output_shape': layer.output_shape - }) + 'layer': layer, + 'input_shape': layer.input_shape, + 'output_shape': layer.output_shape + }) + + +# pylint: enable=bad-continuation 
@experimental diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 5c314f359c..c06fea130f 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -695,6 +695,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":lib_internal", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -3220,7 +3221,6 @@ tf_cc_tests( "lib/gtl/edit_distance_test.cc", "lib/gtl/flatmap_test.cc", "lib/gtl/flatset_test.cc", - "lib/gtl/inlined_vector_test.cc", "lib/gtl/int_type_test.cc", "lib/gtl/iterator_range_test.cc", "lib/gtl/manual_constructor_test.cc", diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index 3bf0532491..84c6285bbe 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -596,7 +596,7 @@ string BFCAllocator::RenderOccupancy() { region_offset += region.memory_size(); } - return std::string(rendered, resolution); + return string(rendered, resolution); } void BFCAllocator::DumpMemoryLog(size_t num_bytes) { diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 39a3b49cd1..879a794368 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -36,22 +36,34 @@ bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) { EagerContext::EagerContext(const SessionOptions& opts, ContextDevicePlacementPolicy default_policy, - bool async, std::unique_ptr<DeviceMgr> device_mgr, + bool async, + std::unique_ptr<const DeviceMgr> device_mgr, Rendezvous* rendezvous) + : EagerContext(opts, default_policy, async, device_mgr.release(), + /*device_mgr_owned*/ true, rendezvous) {} + +EagerContext::EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, + bool async, const DeviceMgr* device_mgr, + bool 
device_mgr_owned, Rendezvous* rendezvous) : policy_(default_policy), - local_device_manager_(std::move(device_mgr)), - local_unowned_device_manager_(nullptr), - devices_(local_device_manager_->ListDevices()), + devices_(device_mgr->ListDevices()), rendezvous_(rendezvous), thread_pool_(NewThreadPoolFromSessionOptions(opts)), pflr_(new ProcessFunctionLibraryRuntime( - local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION, - &func_lib_def_, {}, thread_pool_.get())), + device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {}, + thread_pool_.get())), log_device_placement_(opts.config.log_device_placement()), num_active_steps_(0), async_default_(async), env_(opts.env), use_send_tensor_rpc_(false) { + if (device_mgr_owned) { + local_device_manager_.reset(device_mgr); + local_unowned_device_manager_ = nullptr; + } else { + local_unowned_device_manager_ = device_mgr; + } InitDeviceMapAndAsync(); if (opts.config.inter_op_parallelism_threads() > 0) { runner_ = [this](std::function<void()> closure) { diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 3c95ac590d..eb6eb0d55a 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -65,10 +65,17 @@ enum ContextDevicePlacementPolicy { class EagerContext { public: - explicit EagerContext(const SessionOptions& opts, - ContextDevicePlacementPolicy default_policy, bool async, - std::unique_ptr<DeviceMgr> device_mgr, - Rendezvous* rendezvous); + // TODO: remove this constructor once we migrate all callers to the next one. 
+ EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, bool async, + std::unique_ptr<const DeviceMgr> device_mgr, + Rendezvous* rendezvous); + + EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, bool async, + const DeviceMgr* device_mgr, bool device_mgr_owned, + Rendezvous* rendezvous); + ~EagerContext(); // Returns the function library runtime for the given device. @@ -207,8 +214,8 @@ class EagerContext { thread_local_policies_ GUARDED_BY(policy_map_mu_); // Only one of the below is set. - std::unique_ptr<DeviceMgr> local_device_manager_; - DeviceMgr* local_unowned_device_manager_; + std::unique_ptr<const DeviceMgr> local_device_manager_; + const DeviceMgr* local_unowned_device_manager_; std::unique_ptr<DeviceMgr> remote_device_manager_; // Devices owned by device_manager diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 0a1797fa19..f9aef3af70 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -56,7 +56,7 @@ class SimpleRendezvous : public Rendezvous { } mutex_lock l(mu_); - string edge_name = std::string(parsed.edge_name); + string edge_name(parsed.edge_name); if (table_.count(edge_name) > 0) { return errors::Internal("Send of an already sent tensor"); } @@ -69,7 +69,7 @@ class SimpleRendezvous : public Rendezvous { Tensor tensor; Status status = Status::OK(); { - string key = std::string(parsed.edge_name); + string key(parsed.edge_name); mutex_lock l(mu_); if (table_.count(key) <= 0) { status = errors::Internal("Did not find key ", key); diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc index 7f3c25d81d..3b59995433 100644 --- a/tensorflow/core/common_runtime/placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -254,9 +254,11 @@ class ColocationGraph { old_root_member.device_name, allow_soft_placement_); if 
(!s.ok()) { - return errors::InvalidArgument("Cannot colocate nodes '", x.name(), - "' and '", y.name(), ": ", - s.error_message()); + return errors::InvalidArgument( + "Cannot colocate nodes ", + errors::FormatColocationNodeForError(x.name()), " and ", + errors::FormatColocationNodeForError(y.name()), ": ", + s.error_message()); } // Ensure that the common root has at least one supported device @@ -267,8 +269,10 @@ class ColocationGraph { old_root_member.supported_device_types); if (new_root_member.supported_device_types.empty()) { return errors::InvalidArgument( - "Cannot colocate nodes '", x.name(), "' and '", y.name(), - "' because no device type supports both of those nodes and the " + "Cannot colocate nodes ", + errors::FormatColocationNodeForError(x.name()), " and ", + errors::FormatColocationNodeForError(y.name()), + " because no device type supports both of those nodes and the " "other nodes colocated with them.", DebugInfo(x_root), DebugInfo(y_root)); } @@ -376,8 +380,9 @@ class ColocationGraph { // merged set device is different, so print both. 
return errors::InvalidArgument( "Could not satisfy explicit device specification '", - node->requested_device(), - "' because the node was colocated with a group of nodes that " + node->requested_device(), "' because the node ", + errors::FormatColocationNodeForError(node->name()), + " was colocated with a group of nodes that ", "required incompatible device '", DeviceNameUtils::ParsedNameToString( members_[node_root].device_name), @@ -809,10 +814,10 @@ Status Placer::Run() { std::vector<Device*>* devices; Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { - return AttachDef(errors::InvalidArgument( - "Cannot assign a device for operation ", - RichNodeName(node), ": ", status.error_message()), - *node); + return AttachDef( + errors::InvalidArgument("Cannot assign a device for operation ", + node->name(), ": ", status.error_message()), + *node); } // Returns the first device in sorted devices list so we will always @@ -856,10 +861,10 @@ Status Placer::Run() { std::vector<Device*>* devices; Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { - return AttachDef(errors::InvalidArgument( - "Cannot assign a device for operation ", - RichNodeName(node), ": ", status.error_message()), - *node); + return AttachDef( + errors::InvalidArgument("Cannot assign a device for operation ", + node->name(), ": ", status.error_message()), + *node); } int assigned_device = -1; @@ -925,21 +930,4 @@ void Placer::LogDeviceAssignment(const Node* node) const { } } -bool Placer::ClientHandlesErrorFormatting() const { - return options_ != nullptr && - options_->config.experimental().client_handles_error_formatting(); -} - -// Returns the node name in single quotes. If the client handles formatted -// errors, appends a formatting tag which the client will reformat into, for -// example, " (defined at filename:123)". -// TODO(shikharagarwal): Remove this function once -// client_handles_error_formatting flag is removed. 
-string Placer::RichNodeName(const Node* node) const { - if (ClientHandlesErrorFormatting()) { - return errors::FormatNodeNameForError(node->name()); - } - return strings::StrCat("'", node->name(), "'"); -} - } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h index cefcdd25db..f97ffe7372 100644 --- a/tensorflow/core/common_runtime/placer.h +++ b/tensorflow/core/common_runtime/placer.h @@ -87,8 +87,6 @@ class Placer { // placement if the SessionOptions entry in 'options_' requests it. void AssignAndLog(int assigned_device, Node* node) const; void LogDeviceAssignment(const Node* node) const; - bool ClientHandlesErrorFormatting() const; - string RichNodeName(const Node* node) const; Graph* const graph_; // Not owned. const DeviceSet* const devices_; // Not owned. diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 83d27e2730..9b8a95e3b6 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -800,11 +800,11 @@ TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) { } Status s = Place(&g); - EXPECT_TRUE( - str_util::StrContains(s.error_message(), - "Cannot colocate nodes 'foo' and 'in' because no " - "device type supports both of those nodes and the " - "other nodes colocated with them")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "Cannot colocate nodes {{colocation_node foo}} and " + "{{colocation_node in}} because no device type supports both of those " + "nodes and the other nodes colocated with them")); } TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) { @@ -867,9 +867,9 @@ TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) { Status s = Place(&g); EXPECT_TRUE(str_util::StrContains( s.error_message(), - "Cannot colocate nodes 'var3' and 'assign3' because no " - "device type supports both of those nodes and the other " - 
"nodes colocated with them.")); + "Cannot colocate nodes {{colocation_node var3}} and {{colocation_node " + "assign3}} because no device type supports both of those nodes and the " + "other nodes colocated with them.")); } TEST_F(PlacerTest, TestColocationAndReferenceConnections) { @@ -1154,35 +1154,12 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) { } SessionOptions options; - options.config.mutable_experimental()->set_client_handles_error_formatting( - true); Status s = Place(&g, &options); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); LOG(WARNING) << s.error_message(); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), "Cannot assign a device for operation {{node in}}")); -} - -// Test that the "Cannot assign a device" error message does not contain a -// format tag when not it shouldn't -TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementNoFormatTag) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestDevice", - b.opts().WithName("in").WithDevice("/device:fakegpu:11")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - options.config.mutable_experimental()->set_client_handles_error_formatting( - false); - Status s = Place(&g, &options); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), "Cannot assign a device for operation 'in'")); - EXPECT_FALSE(str_util::StrContains( - s.error_message(), "'in' (defined at ^^node:in:${file}:${line}^^)")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "Cannot assign a device for operation in")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "{{node in}}")); } // Test that placement fails when a node requests an explicit device that is not @@ -1288,8 +1265,9 @@ TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) { Status s = Place(&g); 
EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), "Cannot colocate nodes 'var' and 'assign'")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "Cannot colocate nodes {{colocation_node " + "var}} and {{colocation_node assign}}")); } // Test that a generator node follows its consumers (where there are several diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc index 10a24ed14c..fdad8de8d6 100644 --- a/tensorflow/core/common_runtime/pool_allocator.cc +++ b/tensorflow/core/common_runtime/pool_allocator.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc index 65ff356e73..5b1915755d 100644 --- a/tensorflow/core/common_runtime/session_state.cc +++ b/tensorflow/core/common_runtime/session_state.cc @@ -70,7 +70,7 @@ Status TensorStore::SaveTensors(const std::vector<string>& output_names, // Save only the tensors in output_names in the session. for (const string& name : output_names) { TensorId id(ParseTensorName(name)); - const string& op_name = std::string(id.first); + const string op_name(id.first); auto it = tensors_.find(op_name); if (it != tensors_.end()) { // Save the tensor to the session state. diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index 9c2510e6a9..836cb8ed14 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -176,7 +176,7 @@ static int ExtractGpuWithStreamAll(string device_name) { } else { // Convert the captured string into an integer. 
But first we need to put // the digits back in order - string ordered_capture = std::string(capture); + string ordered_capture(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; CHECK(strings::safe_strto32(ordered_capture, &gpu_id)); @@ -205,7 +205,7 @@ static int ExtractGpuWithoutStream(string device_name) { } else { // Convert the captured string into an integer. But first we need to put // the digits back in order - string ordered_capture = std::string(capture); + string ordered_capture(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; CHECK(strings::safe_strto32(ordered_capture, &gpu_id)); @@ -252,7 +252,7 @@ void StepStatsCollector::BuildCostModel( for (auto& itr : per_device_stats) { const StringPiece device_name = itr.first; - const int gpu_id = ExtractGpuWithoutStream(std::string(device_name)); + const int gpu_id = ExtractGpuWithoutStream(string(device_name)); if (gpu_id >= 0) { // Reference the gpu hardware stats in addition to the regular stats // for this gpu device if they're available. 
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc index ea7788f654..0a38aa1c91 100644 --- a/tensorflow/core/graph/testlib.cc +++ b/tensorflow/core/graph/testlib.cc @@ -485,6 +485,33 @@ Node* DiagPart(Graph* g, Node* in, DataType type) { return ret; } +Node* CheckNumerics(Graph* g, Node* in, const string& message) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics") + .Input(in) + .Attr("message", message) + .Finalize(g, &ret)); + return ret; +} + +Node* Arg(Graph* g, int64 index, DataType type) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg") + .Attr("T", type) + .Attr("index", index) + .Finalize(g, &ret)); + return ret; +} + +Node* Retval(Graph* g, int64 index, Node* in) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval") + .Input(in) + .Attr("index", index) + .Finalize(g, &ret)); + return ret; +} + void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } } // end namespace graph diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h index 8585b35a19..bd0284d43a 100644 --- a/tensorflow/core/graph/testlib.h +++ b/tensorflow/core/graph/testlib.h @@ -209,6 +209,15 @@ Node* Diag(Graph* g, Node* in, DataType type); // Add a DiagPart node in "g". Node* DiagPart(Graph* g, Node* in, DataType type); +// Add a CheckNumerics node in "g". +Node* CheckNumerics(Graph* g, Node* in, const string& message); + +// Add an _Arg node in "g". +Node* Arg(Graph* g, int64 index, DataType type); + +// Add a _Retval node in "g". 
+Node* Retval(Graph* g, int64 index, Node* in); + } // end namespace graph } // end namespace test } // end namespace tensorflow diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 653b088b1d..e78239bd43 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -135,16 +135,37 @@ bool IsDequeueOp(const NodeDef& node) { bool IsDiv(const NodeDef& node) { return node.op() == "Div"; } -bool IsElementWiseMonotonic(const NodeDef& node) { - static const std::unordered_set<string>* element_wise_monotonic_ops = +// Returns true if node represents a unary elementwise function that is +// monotonic. If *is_non_decreasing is true, the function is non-decreasing, +// e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing, +// e.g. inv. +bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) { + static const std::unordered_set<string>* monotonic_non_decreasing_ops = CHECK_NOTNULL((new std::unordered_set<string>{ - "Relu", - "Relu6", - "Sigmoid", - "Sqrt", - "Tanh", + "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1", + "Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint", + "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh", + })); + static const std::unordered_set<string>* monotonic_non_increasing_ops = + CHECK_NOTNULL((new std::unordered_set<string>{ + "Inv", + "Reciprocal", + "Erfc", + "Rsqrt", + "Neg", })); - return element_wise_monotonic_ops->count(node.op()) > 0; + if (monotonic_non_decreasing_ops->count(node.op()) > 0) { + if (is_non_decreasing) { + *is_non_decreasing = true; + } + return true; + } else if (monotonic_non_increasing_ops->count(node.op()) > 0) { + if (is_non_decreasing) { + *is_non_decreasing = false; + } + return true; + } + return false; } bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 94439265c9..25ab6b65ac 
100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -55,7 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node); bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsDiv(const NodeDef& node); -bool IsElementWiseMonotonic(const NodeDef& node); +bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing); bool IsEluGrad(const NodeDef& node); bool IsEnter(const NodeDef& node); bool IsEqual(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 4fed88d536..65947ddce5 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2706,8 +2706,9 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { // 0. inner_function is not in the preserve set, // 1. inner_function's Op is element-wise monotonic // 2. inner_function's output is not being consumed elsewhere. + bool is_non_decreasing = false; if (!IsInPreserveSet(*inner_function) && - IsElementWiseMonotonic(*inner_function) && + IsElementWiseMonotonic(*inner_function, &is_non_decreasing) && ctx().node_map->GetOutputs(inner_function->name()).size() == 1) { // Swap the first inputs of the inner function Op & the reduction Op. NodeDef* inner_input; @@ -2719,7 +2720,12 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { UpdateConsumers(reduction_node, inner_function->name()); ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(), reduction_node->name()); - + if (!is_non_decreasing) { + // Flip Min<->Max if the function is non-increasing, e.g. + // Max(Neg(x)) = Neg(Min(x)). + const string opposite = IsMax(*reduction_node) ? 
"Min" : "Max"; + reduction_node->set_op(opposite); + } AddToOptimizationQueue(reduction_node); AddToOptimizationQueue(inner_function); AddToOptimizationQueue(inner_input); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index bfccc0affd..39517edc06 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -3248,6 +3248,48 @@ TEST_F(ArithmeticOptimizerTest, VerifyGraphsMatch(item.graph, output, __LINE__); } +TEST_F(ArithmeticOptimizerTest, + OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output neg = ops::Neg(s.WithOpName("neg"), x); + Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0}); + Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max); + + GrapplerItem item; + item.fetch = {"final_out"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); + EXPECT_EQ(item.graph.node_size(), output.node_size()); + // Check if the inputs are switched + int required_node_count = 0; + for (int i = 0; i < output.node_size(); ++i) { + const NodeDef& node = output.node(i); + if (node.name() == "neg") { + EXPECT_EQ("Neg", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("reduce_max", node.input(0)); + ++required_node_count; + } else if (node.name() == "reduce_max") { + EXPECT_EQ("Min", node.op()); + EXPECT_EQ(2, 
node.input_size()); + EXPECT_EQ("x", node.input(0)); + ++required_node_count; + } + } + EXPECT_EQ(2, required_node_count); +} + TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index a2c363ea6e..a428aea7f5 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -304,21 +304,21 @@ Status GrapplerFunctionItemInstantiation::GetArgType( } GrapplerFunctionItem::GrapplerFunctionItem( - const string& func_name, const string& description, - const AttrValueMap& func_attr, - const std::vector<InputArgExpansion>& input_arg_expansions, - const std::vector<OutputArgExpansion>& output_arg_expansions, - const std::vector<string>& keep_nodes, const int graph_def_version, - bool is_stateful, GraphDef&& function_body) - : description_(description), - func_attr_(func_attr), - input_arg_expansions_(input_arg_expansions), - output_arg_expansions_(output_arg_expansions), + string func_name, string description, AttrValueMap func_attr, + std::vector<InputArgExpansion> input_arg_expansions, + std::vector<OutputArgExpansion> output_arg_expansions, + std::vector<string> keep_nodes, const int graph_def_version, + const bool is_stateful, GraphDef&& function_body) + : description_(std::move(description)), + func_attr_(std::move(func_attr)), + input_arg_expansions_(std::move(input_arg_expansions)), + output_arg_expansions_(std::move(output_arg_expansions)), is_stateful_(is_stateful) { - id = func_name; - keep_ops = keep_nodes; - // Swap the graph body. - graph.Swap(&function_body); + // Move assign GrapplerItem members. + keep_ops = std::move(keep_nodes); + id = std::move(func_name); + graph = std::move(function_body); + graph.mutable_versions()->set_producer(graph_def_version); // Fill the feed nodes with input placeholders. 
for (const InputArgExpansion& input_arg : input_arg_expansions_) { @@ -598,8 +598,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, *item = GrapplerFunctionItem( /*func_name=*/signature.name(), /*description=*/signature.description(), /*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()), - inputs, outputs, keep_nodes, graph_def_version, is_stateful, - std::move(function_body)); + std::move(inputs), std::move(outputs), std::move(keep_nodes), + graph_def_version, is_stateful, std::move(function_body)); return Status::OK(); } diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index 61588ceb83..733caf325f 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -136,13 +136,12 @@ class GrapplerFunctionItemInstantiation { class GrapplerFunctionItem : public GrapplerItem { public: GrapplerFunctionItem() = default; - GrapplerFunctionItem( - const string& func_name, const string& description, - const AttrValueMap& func_attr, - const std::vector<InputArgExpansion>& input_arg_expansions, - const std::vector<OutputArgExpansion>& output_arg_expansions, - const std::vector<string>& keep_nodes, const int versions, - bool is_stateful, GraphDef&& function_body); + GrapplerFunctionItem(string func_name, string description, + AttrValueMap func_attr, + std::vector<InputArgExpansion> input_arg_expansions, + std::vector<OutputArgExpansion> output_arg_expansions, + std::vector<string> keep_nodes, int graph_def_version, + bool is_stateful, GraphDef&& function_body); const string& description() const; diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index e7b3d0c92f..3a1ac73f64 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -51,6 +51,7 @@ cc_library( hdrs = ["captured_function.h"], deps = [ ":dataset", + ":single_threaded_executor", "//tensorflow/core:core_cpu_internal", 
"//tensorflow/core:framework", "//tensorflow/core:lib", @@ -61,6 +62,42 @@ cc_library( ) cc_library( + name = "single_threaded_executor", + srcs = ["single_threaded_executor.cc"], + hdrs = ["single_threaded_executor.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "single_threaded_executor_test", + srcs = ["single_threaded_executor_test.cc"], + deps = [ + ":single_threaded_executor", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:math", + "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:state", + ], +) + +cc_library( name = "window_dataset", srcs = ["window_dataset.cc"], hdrs = ["window_dataset.h"], diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index abdf6ee4e8..186740c2ac 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -28,7 +28,16 @@ namespace tensorflow { Status CapturedFunction::Create( const NameAttrList& func, std::vector<Tensor> captured_inputs, std::unique_ptr<CapturedFunction>* out_function) { - out_function->reset(new CapturedFunction(func, std::move(captured_inputs))); + return Create(func, std::move(captured_inputs), true, out_function); +} + +/* static */ +Status CapturedFunction::Create( + const NameAttrList& func, std::vector<Tensor> captured_inputs, + bool use_inter_op_parallelism, + std::unique_ptr<CapturedFunction>* out_function) 
{ + out_function->reset(new CapturedFunction(func, std::move(captured_inputs), + use_inter_op_parallelism)); return Status::OK(); } @@ -272,6 +281,9 @@ Status CapturedFunction::Instantiate(IteratorContext* ctx) { inst_opts.overlay_lib = ctx->function_library().get(); inst_opts.state_handle = std::to_string(random::New64()); inst_opts.create_kernels_eagerly = true; + if (!use_inter_op_parallelism_) { + inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR"; + } Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_opts, &f_handle_)); TF_RETURN_IF_ERROR(s); @@ -398,10 +410,12 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, } CapturedFunction::CapturedFunction(const NameAttrList& func, - std::vector<Tensor> captured_inputs) + std::vector<Tensor> captured_inputs, + bool use_inter_op_parallelism) : func_(func), lib_(nullptr), f_handle_(kInvalidHandle), - captured_inputs_(std::move(captured_inputs)) {} + captured_inputs_(std::move(captured_inputs)), + use_inter_op_parallelism_(use_inter_op_parallelism) {} } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index c95f2b1c01..ae6bdfc2a0 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -48,6 +48,15 @@ class CapturedFunction { std::vector<Tensor> captured_inputs, std::unique_ptr<CapturedFunction>* out_function); + // Creates a new instance from a list of named attributes and captured inputs. + // + // If `low_latency_hint` is true, the runtime may use an executor that is + // optimized for small functions. + static Status Create(const NameAttrList& func, + std::vector<Tensor> captured_inputs, + bool use_inter_op_parallelism, + std::unique_ptr<CapturedFunction>* out_function); + // Creates a new instance using a list of named attributes, fetching captured // inputs from a context argument. 
static Status Create(const NameAttrList& func, OpKernelContext* ctx, @@ -114,7 +123,8 @@ class CapturedFunction { private: CapturedFunction(const NameAttrList& func, - std::vector<Tensor> captured_inputs); + std::vector<Tensor> captured_inputs, + bool use_inter_op_parallelism); Status GetHandle(IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle); @@ -126,6 +136,7 @@ class CapturedFunction { const std::vector<Tensor> captured_inputs_; DataTypeSlice ret_types_; std::function<void(std::function<void()>)> captured_runner_ = nullptr; + const bool use_inter_op_parallelism_; TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction); }; diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 7f8182d917..6c45fcafcc 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -34,6 +34,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", + &use_inter_op_parallelism_)); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, @@ -48,7 +50,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<CapturedFunction> captured_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + func_, std::move(other_arguments), + use_inter_op_parallelism_, &captured_func)); *output = new Dataset(ctx, input, func_, std::move(captured_func), output_types_, output_shapes_); @@ -187,6 +190,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; NameAttrList func_; + bool use_inter_op_parallelism_; }; REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp); diff --git 
a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc new file mode 100644 index 0000000000..e785b8b4d5 --- /dev/null +++ b/tensorflow/core/kernels/data/single_threaded_executor.cc @@ -0,0 +1,378 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/single_threaded_executor.h" + +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/executor_factory.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace { + +typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec; +typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec; +typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec; + +class SingleThreadedExecutorImpl : public Executor { + public: + explicit SingleThreadedExecutorImpl(const LocalExecutorParams& params) + : params_(params) {} + + ~SingleThreadedExecutorImpl() override { + for (const KernelState& kernel_state : kernels_) { + params_.delete_kernel(kernel_state.kernel); + } + } + + Status Initialize(const Graph& graph) { + // Topologically sort `graph` to get a sequence of OpKernels.
+ std::vector<Node*> ordered_nodes; + ordered_nodes.reserve(graph.num_nodes()); + GetReversePostOrder(graph, &ordered_nodes); + + if (ordered_nodes.size() != graph.num_nodes()) { + return errors::InvalidArgument("Graph had ", graph.num_nodes(), + " but reverse post-order had ", + ordered_nodes.size()); + } + + kernels_.resize(ordered_nodes.size()); + + std::unordered_map<Node*, size_t> node_to_index_map; + + // Create the kernel and input-related structures for each node in `graph`. + for (size_t i = 0; i < ordered_nodes.size(); ++i) { + Node* n = ordered_nodes[i]; + node_to_index_map[n] = i; + + for (DataType dt : n->output_types()) { + if (IsRefType(dt)) { + return errors::Unimplemented( + "Single-threaded executor does not support reference-typed " + "edges."); + } + } + + if (n->IsControlFlow()) { + return errors::Unimplemented( + "Single-threaded executor does not support control flow."); + } + if (n->IsSend() || n->IsHostSend() || n->IsRecv() || n->IsHostRecv()) { + return errors::Unimplemented( + "Single-threaded executor does not support partitioned graphs."); + } + if (n->IsCollective()) { + return errors::Unimplemented( + "Single-threaded executor does not support collective ops."); + } + + KernelState& kernel_state = kernels_[i]; + TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel)); + kernel_state.num_inputs = n->num_inputs(); + kernel_state.num_outputs = n->num_outputs(); + + if (i == 0) { + kernel_state.input_start_index = 0; + } else { + const KernelState& previous_kernel_state = kernels_[i - 1]; + kernel_state.input_start_index = + previous_kernel_state.input_start_index + + previous_kernel_state.num_inputs; + } + } + + // Build the mapping from each node output to the input slot for the + // corresponding destination node. 
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) { + Node* n = ordered_nodes[i]; + KernelState& kernel_state = kernels_[i]; + kernel_state.output_locations.resize(kernel_state.num_outputs); + for (const Edge* e : n->out_edges()) { + if (!e->IsControlEdge()) { + kernel_state.output_locations[e->src_output()].push_back( + kernels_[node_to_index_map[e->dst()]].input_start_index + + e->dst_input()); + } + } + + // Compute allocator attributes for each node output, and corresponding + // node input. + kernel_state.output_alloc_attrs.resize(kernel_state.num_outputs); + AllocatorAttributes* attrs = kernel_state.output_alloc_attrs.data(); + + OpKernel* op_kernel = kernel_state.kernel; + for (int out = 0; out < n->num_outputs(); out++) { + DCHECK_LT(out, op_kernel->output_memory_types().size()); + bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY; + if (on_host) { + AllocatorAttributes h; + h.set_on_host(on_host); + attrs[out].Merge(h); + } + } + } + + if (!kernels_.empty()) { + const KernelState& last_kernel_state = kernels_.back(); + total_num_inputs_ = + last_kernel_state.input_start_index + last_kernel_state.num_inputs; + input_alloc_attrs_.resize(total_num_inputs_); + for (size_t i = 0; i < ordered_nodes.size(); ++i) { + for (size_t j = 0; j < kernels_[i].output_locations.size(); ++j) { + for (size_t output_location : kernels_[i].output_locations[j]) { + input_alloc_attrs_[output_location] = + kernels_[i].output_alloc_attrs[j]; + } + } + } + } else { + total_num_inputs_ = 0; + } + return Status::OK(); + } + + // TODO(mrry): Consider specializing the implementation of Executor::Run() + // instead, to avoid unnecessary atomic operations in the callback when + // running synchronously. + void RunAsync(const Args& args, DoneCallback done) override { + // The inputs to each kernel are stored contiguously in `inputs`. 
+ // + // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to + // determine the range of elements in this vector that correspond to + // the inputs of `kernels_[i]`. + // + // This vector has the following layout: + // + // * Kernel 0, input 0. + // * Kernel 0, input 1. + // * ... + // * Kernel 0, input `kernels_[0].num_inputs - 1`. + // * Kernel 1, input 0. + // * ... + // * Kernel 1, input `kernels_[1].num_inputs - 1`. + // * ... + // * Kernel `kernels_.size() - 1`, input 0. + // * ... + // * Kernel `kernels_.size() - 1`, input `kernels_.back().num_inputs - 1`. + // + // Note that kernels with zero inputs do not correspond to any elements in + // this vector. + // + // We use `ManualConstructor<Tensor>` to avoid the overhead of + // default-constructing an invalid `Tensor` for each slot at the beginning + // of execution: + // * Elements are initialized when the outputs of a kernel execution are + // propagated to the inputs of kernels that depend on them. + // * The elements corresponding to the inputs for kernel `i` are destroyed + // after kernel `i` executes. + // * In an error case (see below), we use the connectivity information in + // `KernelState::output_locations` to determine which locations have been + // initialized, and manually destroy them. + std::vector<ManualConstructor<Tensor>> inputs(total_num_inputs_); + + // TODO(mrry): Can we avoid copying into these vectors? Consider modifying + // OpKernelContext to take the TensorValueVec as a pointer into `inputs`. + TensorValueVec node_inputs; + DeviceContextVec input_device_contexts; + AllocatorAttributeVec input_alloc_attrs; + + // Prepare the parameters that will be the same for all kernels. + OpKernelContext::Params params; + params.step_id = args.step_id; + Device* device = params_.device; + params.device = device; + params.log_memory = false; // TODO(mrry): Too severe? + params.record_tensor_accesses = false; // TODO(mrry): Too severe? 
+ params.rendezvous = args.rendezvous; + params.session_state = args.session_state; + params.tensor_store = args.tensor_store; + params.cancellation_manager = args.cancellation_manager; + // TODO(mrry): ArgOp is a relatively expensive OpKernel due to the Tensor + // allocations that it performs. Consider specializing its handling in the + // executor. + params.call_frame = args.call_frame; + params.function_library = params_.function_library; + params.resource_manager = device->resource_manager(); + params.step_container = args.step_container; + params.slice_reader_cache = nullptr; // TODO(mrry): Too severe? + params.inputs = &node_inputs; + params.input_device_contexts = &input_device_contexts; + params.input_alloc_attrs = &input_alloc_attrs; + + Args::Runner runner_copy = args.runner; + params.runner = &runner_copy; + params.stats_collector = args.stats_collector; + + // NOTE(mrry): We are assuming that the graph is loopless and condless. + params.frame_iter = FrameAndIter(0, 0); + params.is_input_dead = false; + + // TODO(mrry): Add non-default device context inference. + params.op_device_context = nullptr; + // TODO(mrry): Consider implementing forwarding. + params.forward_from_array = nullptr; + + // Execute the kernels one-at-a-time in topological order. + for (size_t i = 0; i < kernels_.size(); ++i) { + const KernelState& kernel_state = kernels_[i]; + + // Prepare the per-kernel parameters. 
+ const size_t input_start_index = kernel_state.input_start_index; + const size_t num_inputs = kernel_state.num_inputs; + const size_t num_outputs = kernel_state.num_outputs; + + node_inputs.clear(); + node_inputs.resize(num_inputs); + input_alloc_attrs.clear(); + input_alloc_attrs.resize(num_inputs); + for (size_t j = 0; j < num_inputs; ++j) { + auto t = inputs[input_start_index + j].get(); + node_inputs[j].tensor = t; + input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j]; + } + params.op_kernel = kernel_state.kernel; + input_device_contexts.clear(); + input_device_contexts.resize(num_inputs); + params.output_attr_array = kernel_state.output_alloc_attrs.data(); + OpKernelContext ctx(&params, num_outputs); + + // Actually execute the kernel. + device->Compute(kernel_state.kernel, &ctx); + + if (!ctx.status().ok()) { + // On failure, we must manually free all intermediate tensors. We have + // already freed all the inputs for kernels up to (but not including) + // the `i`th kernel. We scan through the previously executed kernels and + // destroy any tensors that were destined to be the input for a kernel + // that has not yet executed. + for (size_t j = 0; j < i; ++j) { + const KernelState& executed_kernel_state = kernels_[j]; + for (size_t k = 0; k < executed_kernel_state.num_outputs; ++k) { + for (size_t output_location : + executed_kernel_state.output_locations[k]) { + if (output_location >= input_start_index) { + // Only destroy an output location if it is an input to an + // operation that has not yet executed. + inputs[output_location].Destroy(); + } + } + } + } + done(ctx.status()); + return; + } + + // Free the inputs to the current kernel. + for (size_t j = 0; j < num_inputs; ++j) { + inputs[input_start_index + j].Destroy(); + } + + // Forward the outputs of the kernel to the inputs of subsequent kernels.
+ for (size_t j = 0; j < num_outputs; ++j) { + TensorValue val = ctx.release_output(j); + // TODO(mrry): Consider flattening the `output_locations` vector + // to improve the cache-friendliness of this loop. + for (size_t output_location : kernel_state.output_locations[j]) { + // TODO(mrry): Validate that the types match the expected values or + // ensure that the necessary validation has already happened. + inputs[output_location].Init(*val.tensor); + } + delete val.tensor; + } + } + done(Status::OK()); + } + + private: + const LocalExecutorParams params_; + + // All following members are read-only after Initialize(). + + // The sum of the number of inputs for each node in the graph. This determines + // the length of the flat `inputs` vector. See comment at the beginning of + // `RunAsync()` for details. + size_t total_num_inputs_; + + // Represents cached graph structure state for each kernel. + struct KernelState { + // The kernel object. Not owned. + // + // This pointer is managed by `params_.create_kernel()` and + // `params_.delete_kernel()`. + OpKernel* kernel; + + // These fields determine the range of elements in `inputs` that corresponds + // to the inputs of `kernel`. + size_t input_start_index; + size_t num_inputs; + + size_t num_outputs; + + // For the `j`th output of `kernel`, `output_locations[j]` contains the + // locations in the flat `inputs` vector to which that output must be + // copied. See comment at the beginning of `RunAsync()` for details. + std::vector<std::vector<size_t>> + output_locations; // Length = `num_outputs`. + + // Memory space information for each output of `kernel`. + std::vector<AllocatorAttributes> + output_alloc_attrs; // Length = `num_outputs`. + }; + std::vector<KernelState> kernels_; + + // Memory space information for each input. This information is stored in the + // same order as the flat `inputs` vector. See comment at the beginning of + // `RunAsync()` for details. 
+ std::vector<AllocatorAttributes> + input_alloc_attrs_; // Length = `total_num_inputs_`. +}; + +class SingleThreadedExecutorRegistrar { + public: + SingleThreadedExecutorRegistrar() { + ExecutorFactory::Register("SINGLE_THREADED_EXECUTOR", new Factory()); + } + + private: + class Factory : public ExecutorFactory { + Status NewExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + std::unique_ptr<Executor>* out_executor) override { + Executor* ret; + TF_RETURN_IF_ERROR( + NewSingleThreadedExecutor(params, std::move(graph), &ret)); + out_executor->reset(ret); + return Status::OK(); + } + }; +}; +static SingleThreadedExecutorRegistrar registrar; + +} // namespace + +Status NewSingleThreadedExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + Executor** executor) { + std::unique_ptr<SingleThreadedExecutorImpl> impl( + new SingleThreadedExecutorImpl(params)); + TF_RETURN_IF_ERROR(impl->Initialize(*graph)); + *executor = impl.release(); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/single_threaded_executor.h b/tensorflow/core/kernels/data/single_threaded_executor.h new file mode 100644 index 0000000000..15836b24c9 --- /dev/null +++ b/tensorflow/core/kernels/data/single_threaded_executor.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ + +#include "tensorflow/core/common_runtime/executor.h" + +namespace tensorflow { + +// Creates a new `Executor` for executing `graph` synchronously on the caller +// thread. +// +// NOTE(mrry): The returned executor is optimized to impose low overhead on +// graphs that perform a small amount of work (e.g. <15us of work per graph on +// present architectures). It eschews concurrency, because issuing work to +// multiple threads can dominate the cost of executing small ops synchronously, +// and because contention in the executor data structures can reduce throughput +// (in terms of ops executed per unit time). +// +// However, the current implementation has the following limitations: +// +// 1. Reference-typed tensors are not supported and will not be supported in +// future. +// 2. Graphs with control flow (containing "Switch" and "Merge" nodes) are not +// currently supported. The current plan is to extend support to "functional" +// control flow after the TensorFlow APIs transition to building graphs in +// that form (e.g. `tf.cond_v2()`). +// 3. Partitioned graphs (containing "_Recv" nodes) are not currently supported. +// The present implementation executes kernels one at a time in topological +// order, and cannot currently distinguish between disconnected subgraphs +// that are logically connected by subgraphs on a different device. +// 4. Memory logging is not currently supported. +// 5. Allocation forwarding is not currently supported. +// 6. Non-default device contexts are not currently supported. In effect, this +// limits the executor to CPU devices. +// 7. Ops that rely on `OpKernelContext::slice_reader_cache()` being non-null +// are not currently supported. 
+// +// The single-threaded executor is primarily suitable for executing simple +// TensorFlow functions, such as one might find in a `tf.data` pipeline. +Status NewSingleThreadedExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + Executor** executor); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc new file mode 100644 index 0000000000..f8b5769197 --- /dev/null +++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc @@ -0,0 +1,330 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/data/single_threaded_executor.h" + +#include <algorithm> + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { + +class ExecutorTest : public ::testing::Test { + protected: + ExecutorTest() + : device_(DeviceFactory::NewDevice("CPU", {}, + "/job:localhost/replica:0/task:0")) {} + + ~ExecutorTest() override { + // There should always be exactly one Ref left on the Rendezvous + // when the test completes. + CHECK(rendez_->Unref()); + delete exec_; + delete device_; + } + + // Resets executor_ with a new executor based on a graph 'gdef'. 
+ void Create(std::unique_ptr<const Graph> graph) { + const int version = graph->versions().producer(); + LocalExecutorParams params; + params.device = device_; + params.create_kernel = [this, version](const NodeDef& ndef, + OpKernel** kernel) { + return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel); + }; + params.delete_kernel = [](OpKernel* kernel) { + DeleteNonCachedKernel(kernel); + }; + delete exec_; + TF_CHECK_OK(NewSingleThreadedExecutor(params, std::move(graph), &exec_)); + runner_ = [](std::function<void()> fn) { fn(); }; + rendez_ = NewLocalRendezvous(); + } + + Status Run(Rendezvous* rendez) { + Executor::Args args; + args.rendezvous = rendez; + args.runner = runner_; + return exec_->Run(args); + } + + Status Run(CallFrameInterface* call_frame) { + Executor::Args args; + args.call_frame = call_frame; + args.runner = runner_; + return exec_->Run(args); + } + + Device* device_ = nullptr; + Executor* exec_ = nullptr; + Executor::Args::Runner runner_; + Rendezvous* rendez_ = nullptr; +}; + +// A float val -> Tensor<float> +Tensor V(const float val) { + Tensor tensor(DT_FLOAT, TensorShape({})); + tensor.scalar<float>()() = val; + return tensor; +} + +// A int32 val -> Tensor<int32> +Tensor VI(const int32 val) { + Tensor tensor(DT_INT32, TensorShape({})); + tensor.scalar<int32>()() = val; + return tensor; +} + +// A bool val -> Tensor<bool> +Tensor VB(const bool val) { + Tensor tensor(DT_BOOL, TensorShape({})); + tensor.scalar<bool>()() = val; + return tensor; +} + +// A double val -> Tensor<double> +Tensor VD(const double val) { + Tensor tensor(DT_DOUBLE, TensorShape({})); + tensor.scalar<double>()() = val; + return tensor; +} + +// Tensor<float> -> a float val. 
+float V(const Tensor& tensor) { + CHECK_EQ(tensor.dtype(), DT_FLOAT); + CHECK(TensorShapeUtils::IsScalar(tensor.shape())); + return tensor.scalar<float>()(); +} + +Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation, + const string& receiver, const string& name) { + Rendezvous::ParsedKey result; + TF_CHECK_OK( + Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver, + name, FrameAndIter(0, 0)), + &result)); + return result; +} + +TEST_F(ExecutorTest, SimpleAdd) { + // c = a + b + std::unique_ptr<Graph> g(new Graph(OpRegistry::Global())); + auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT); + // `in1` must read argument 1; reading argument 0 twice would ignore `b`. + auto in1 = test::graph::Arg(g.get(), 1, DT_FLOAT); + auto tmp = test::graph::Add(g.get(), in0, in1); + test::graph::Retval(g.get(), 0, tmp); + FixupSourceAndSinkEdges(g.get()); + Create(std::move(g)); + FunctionCallFrame call_frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT}); + // Use distinct inputs so that a mixed-up argument index changes the result. + TF_ASSERT_OK(call_frame.SetArgs({V(1.0), V(2.0)})); + TF_ASSERT_OK(Run(&call_frame)); + std::vector<Tensor> retvals; + TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false)); + EXPECT_EQ(3.0, V(retvals[0])); // out = 1.0 + 2.0 = 3.0 +} + +TEST_F(ExecutorTest, SelfAdd) { + // v0 <- a + // v1 = v0 + v0 + // v2 = v1 + v1 + // ... ... + // v10 = v9 + v9 + // + // b <- v10 + // All nodes are executed by one thread.
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global())); + auto v = test::graph::Arg(g.get(), 0, DT_FLOAT); + const int N = 10; + for (int i = 1; i <= N; ++i) { + v = test::graph::Add(g.get(), v, v); + } + // out <- v10 + test::graph::Retval(g.get(), 0, v); + FixupSourceAndSinkEdges(g.get()); + Create(std::move(g)); + FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT}); + // a = 1.0 + TF_ASSERT_OK(call_frame.SetArgs({V(1.0)})); + TF_ASSERT_OK(Run(&call_frame)); + std::vector<Tensor> retvals; + TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false)); + EXPECT_EQ(1024.0, V(retvals[0])); // b=v10=2*v9=4*v8=...=1024*a=1024.0 +} + +// Builds a graph which adds N copies of one variable "in". I.e., +// a + a + a + ... + a +// The returned graph is parenthesized randomly. I.e., +// a + ((a + a) + a) +// (a + a) + (a + a) +// ((a + a) + a) + a +// are all possibly generated. +void BuildTree(int N, Graph* g) { + CHECK_GT(N, 1); + // A single input node "in". + auto in = test::graph::Arg(g, 0, DT_FLOAT); + std::vector<Node*> nodes; + int i = 0; + // Duplicate "in" N times. Each copy is named l0, l1, l2, .... + for (; i < N; ++i) { + nodes.push_back(test::graph::Identity(g, in, 0)); + } + random::PhiloxRandom philox(0, 17); + random::SimplePhilox rnd(&philox); + while (nodes.size() > 1) { + // Randomly pick two from nodes and add them. The resulting node + // is named like n10, n11, .... and is put back into "nodes". + int x = rnd.Uniform(nodes.size()); + auto in0 = nodes[x]; + nodes[x] = nodes.back(); + nodes.resize(nodes.size() - 1); + x = rnd.Uniform(nodes.size()); + auto in1 = nodes[x]; + // node = in0 + in1. + nodes[x] = test::graph::Add(g, in0, in1); + } + // The final output node "out".
+ test::graph::Retval(g, 0, nodes.back()); + FixupSourceAndSinkEdges(g); +} + +TEST_F(ExecutorTest, RandomTree) { + std::unique_ptr<Graph> g(new Graph(OpRegistry::Global())); + BuildTree(4096, g.get()); + Create(std::move(g)); + FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT}); + TF_ASSERT_OK(call_frame.SetArgs({V(1.0)})); + TF_ASSERT_OK(Run(&call_frame)); + std::vector<Tensor> retvals; + TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false)); + EXPECT_EQ(4096.0, V(retvals[0])); +} + +TEST_F(ExecutorTest, OpError) { + std::unique_ptr<Graph> g(new Graph(OpRegistry::Global())); + auto zero = test::graph::Constant(g.get(), V(0.0)); + auto inf = test::graph::Unary(g.get(), "Reciprocal", zero); + auto check = test::graph::CheckNumerics(g.get(), inf, "message"); + auto two = test::graph::Constant(g.get(), V(2.0)); + test::graph::Binary(g.get(), "Mul", check, two); + FixupSourceAndSinkEdges(g.get()); + Create(std::move(g)); + FunctionCallFrame call_frame({}, {}); + // Fails because CheckNumerics rejects the Inf produced by Reciprocal(0).
+ EXPECT_TRUE(errors::IsInvalidArgument(Run(&call_frame))); +} + +static void BM_executor(int iters, int width, int depth) { +#ifdef PLATFORM_GOOGLE + BenchmarkUseRealTime(); +#endif // PLATFORM_GOOGLE + Graph* g = new Graph(OpRegistry::Global()); + random::PhiloxRandom philox(1729, 17); + random::SimplePhilox rand(&philox); + uint64 cur = 0; + uint32 r = 1 + rand.Rand32() % width; + std::vector<Node*> ready_nodes; + for (int i = 0; i < r; ++i) { + ready_nodes.push_back(test::graph::NoOp(g, {})); + ++cur; + } + for (int i = 0; i < depth; ++i) { + std::random_shuffle(ready_nodes.begin(), ready_nodes.end()); + r = 1 + rand.Rand32() % (ready_nodes.size()); + std::vector<Node*> control_inputs; + for (int j = 0; j < r; ++j) { + control_inputs.push_back(ready_nodes.back()); + ready_nodes.pop_back(); + } + Node* n = test::graph::NoOp(g, control_inputs); + ++cur; + r = 1 + rand.Rand32() % width; + for (int j = 0; j < r; ++j) { + ready_nodes.push_back(test::graph::NoOp(g, {n})); + ++cur; + } + } + FixupSourceAndSinkEdges(g); +#ifdef PLATFORM_GOOGLE + SetBenchmarkLabel(strings::StrCat("Nodes = ", cur)); + SetBenchmarkItemsProcessed(cur * static_cast<int64>(iters)); +#endif // PLATFORM_GOOGLE + test::Benchmark("cpu", g, nullptr, nullptr, nullptr, + "SINGLE_THREADED_EXECUTOR") + .Run(iters); +} + +// Tall skinny graphs +BENCHMARK(BM_executor)->ArgPair(16, 1024); +BENCHMARK(BM_executor)->ArgPair(32, 8192); + +// Short fat graphs +BENCHMARK(BM_executor)->ArgPair(1024, 16); +BENCHMARK(BM_executor)->ArgPair(8192, 32); + +// Tall fat graph +BENCHMARK(BM_executor)->ArgPair(1024, 1024); + +// TODO(mrry): This benchmark currently crashes with a use-after free, because +// test::Benchmark::RunWithArgs() assumes that the executor will take ownership +// of the given graph, *and* keep its nodes (`x`, `y` and `z`) alive for the +// duration of the benchmark. Since the single threaded executor does not retain +// a copy of the graph, this fails. 
+// +// TODO(mrry): Add support for Arg/Retval "function call convention" in +// `test::Benchmark::RunWithArgs()`. +#if 0 +#define ALICE "/job:j/replica:0/task:0/cpu:0" +#define BOB "/job:j/replica:0/task:0/gpu:0" + +static void BM_FeedInputFetchOutput(int iters) { + Graph* g = new Graph(OpRegistry::Global()); + // z = x + y: x and y are provided as benchmark inputs. z is the + // output of the benchmark. Conceptually, the caller is ALICE, the + // benchmark is BOB. + Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB); + Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB); + Node* sum = test::graph::Add(g, x, y); + Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE); + FixupSourceAndSinkEdges(g); + Tensor val(DT_FLOAT, TensorShape({})); + val.scalar<float>()() = 3.14; + SetBenchmarkItemsProcessed(static_cast<int64>(iters)); + test::Benchmark("cpu", g, nullptr, nullptr, nullptr, + "SINGLE_THREADED_EXECUTOR") + .RunWithArgs({{x, val}, {y, val}}, {z}, iters); +} +BENCHMARK(BM_FeedInputFetchOutput); +#endif + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h index 33ed5522d0..d705e82b0d 100644 --- a/tensorflow/core/kernels/debug_ops.h +++ b/tensorflow/core/kernels/debug_ops.h @@ -255,7 +255,7 @@ class DebugNanCountOp : public BaseDebugOp { TensorShape shape({1}); OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor)); output_tensor->vec<int64>()(0) = nan_count; - PublishTensor(*output_tensor); + OP_REQUIRES_OK(context, PublishTensor(*output_tensor)); } }; @@ -380,7 +380,7 @@ class DebugNumericSummaryOp : public BaseDebugOp { bool mute = mute_if_healthy_ && nan_count == 0 && negative_inf_count == 0 && positive_inf_count == 0; if (!mute) { - PublishTensor(*output_tensor); + OP_REQUIRES_OK(context, PublishTensor(*output_tensor)); } } diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h 
b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h index e13e548f86..3ebeb7be2b 100644 --- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h @@ -323,47 +323,34 @@ CuboidConvolutionBackwardInput( template <typename OutputBackward, typename Input> EIGEN_ALWAYS_INLINE static const typename internal::conditional< internal::traits<OutputBackward>::Layout == ColMajor, - const TensorShufflingOp< - const array<typename internal::traits<OutputBackward>::Index, 5>, - const TensorReverseOp< - const array<bool, 5>, + TensorReshapingOp< + const DSizes<typename internal::traits<Input>::Index, 5>, + const TensorContractionOp< + const array<IndexPair<typename internal::traits<Input>::Index>, 1>, const TensorReshapingOp< - const DSizes<typename internal::traits<OutputBackward>::Index, - 5>, - const TensorContractionOp< - const array< - IndexPair<typename internal::traits<Input>::Index>, 2>, - const TensorReshapingOp< - const DSizes<typename internal::traits<Input>::Index, - 3>, - const Input>, - const TensorReshapingOp< - const DSizes< - typename internal::traits<OutputBackward>::Index, - 4>, - const TensorVolumePatchOp< - Dynamic, Dynamic, Dynamic, - const OutputBackward> > > > > >, - const TensorShufflingOp< - const array<typename internal::traits<OutputBackward>::Index, 5>, - const TensorReverseOp< - const array<bool, 5>, + const DSizes<typename internal::traits<Input>::Index, 2>, + const OutputBackward>, + const TensorShufflingOp< + const array<typename internal::traits<OutputBackward>::Index, + 2>, + const TensorReshapingOp< + const DSizes<typename internal::traits<Input>::Index, 2>, + const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, + const Input> > > > >, + TensorReshapingOp< + const DSizes<typename internal::traits<Input>::Index, 5>, + const TensorContractionOp< + const array<IndexPair<typename internal::traits<Input>::Index>, 1>, + const TensorShufflingOp< + const array<typename 
internal::traits<OutputBackward>::Index, + 2>, + const TensorReshapingOp< + const DSizes<typename internal::traits<Input>::Index, 2>, + const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, + const Input> > >, const TensorReshapingOp< - const DSizes<typename internal::traits<OutputBackward>::Index, - 5>, - const TensorContractionOp< - const array< - IndexPair<typename internal::traits<Input>::Index>, 2>, - const TensorReshapingOp< - const DSizes< - typename internal::traits<OutputBackward>::Index, - 4>, - const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, - const OutputBackward> >, - const TensorReshapingOp< - const DSizes<typename internal::traits<Input>::Index, - 3>, - const Input> > > > > >::type + const DSizes<typename internal::traits<Input>::Index, 2>, + const OutputBackward> > > >::type CuboidConvolutionBackwardKernel( const Input& input, const OutputBackward& output_backward, typename internal::traits<Input>::Index kernelPlanes, @@ -406,213 +393,114 @@ CuboidConvolutionBackwardKernel( const TensorIndex outputCols = isColMajor ? out.dimension(3) : out.dimension(NumDims - 4); + // Number of filters. This is the same as the output depth. const TensorIndex kernelFilters = isColMajor ? out.dimension(0) : out.dimension(NumDims - 1); + // Number of channels. This is the same as the input depth. const TensorIndex kernelChannels = isColMajor ? in.dimension(0) : in.dimension(NumDims - 1); - TensorIndex forward_pad_z, forward_pad_y, forward_pad_x; - const TensorIndex size_z = - Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes)); - const TensorIndex size_y = - Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows)); - const TensorIndex size_x = - Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols)); - - // Infer padding type. - if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) { - // SAME padding. 
- const TensorIndex dz = numext::maxi<TensorIndex>( - 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes); - const TensorIndex dy = numext::maxi<TensorIndex>( - 0, (size_y - 1) * strideRows + kernelRows - inputRows); - const TensorIndex dx = numext::maxi<TensorIndex>( - 0, (size_x - 1) * strideCols + kernelCols - inputCols); - - forward_pad_z = dz / 2; - forward_pad_y = dy / 2; - forward_pad_x = dx / 2; - } else { - // VALID padding. - forward_pad_z = 0; - forward_pad_y = 0; - forward_pad_x = 0; - } - - const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z; - const TensorIndex padding_top = kernelRows - 1 - forward_pad_y; - const TensorIndex padding_left = kernelCols - 1 - forward_pad_x; - - const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 - - (outputPlanes - 1) * stridePlanes - 1 - - padding_ztop; - const TensorIndex padding_bottom = inputRows + kernelRows - 1 - - (outputRows - 1) * strideRows - 1 - - padding_top; - const TensorIndex padding_right = inputCols + kernelCols - 1 - - (outputCols - 1) * strideCols - 1 - - padding_left; - - eigen_assert(padding_ztop >= 0); - eigen_assert(padding_zbottom >= 0); - eigen_assert(padding_top >= 0); - eigen_assert(padding_left >= 0); - eigen_assert(padding_bottom >= 0); - eigen_assert(padding_right >= 0); - - // The output_backward has dimensions out_depth X out_plaens X out_rows X - // out_cols X OTHERS - // When we extract the image patches from output_backward (with input as the - // kernel), it will have dimensions - // (out_depth) X (input_planes * input_rows * input_cols) X (kernel_planes * - // kernel_rows * kernel_cols) X OTHERS - DSizes<TensorIndex, 4> pre_contract_dims; + // TODO(ezhulenev): Add support for inflated strides. Without inflated strides + // effective kernel planes/rows/cols are always the same as the kernel itself + // (see eigen_spatial_convolutions for details). 
+ const TensorIndex kernelPlanesEff = kernelPlanes; + const TensorIndex kernelRowsEff = kernelRows; + const TensorIndex kernelColsEff = kernelCols; + + const TensorIndex padPlanes = numext::maxi<Index>( + 0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes); + const TensorIndex padRows = numext::maxi<Index>( + 0, (outputRows - 1) * strideRows + kernelRowsEff - inputRows); + const TensorIndex padCols = numext::maxi<Index>( + 0, (outputCols - 1) * strideCols + kernelColsEff - inputCols); + + const TensorIndex padding_top_z = padPlanes / 2; + const TensorIndex padding_bottom_z = padPlanes - padding_top_z; + const TensorIndex padding_top = padRows / 2; + const TensorIndex padding_bottom = padRows - padding_top; + const TensorIndex padding_left = padCols / 2; + const TensorIndex padding_right = padCols - padding_left; + + // Reshaped output_backward before contraction. + DSizes<TensorIndex, 2> output_dims; if (isColMajor) { - pre_contract_dims[0] = kernelFilters; - pre_contract_dims[1] = inputRows * inputCols * inputPlanes; - pre_contract_dims[2] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[3] = 1; + output_dims[0] = kernelFilters; + output_dims[1] = outputPlanes * outputRows * outputCols; for (int i = 4; i < NumDims; ++i) { - pre_contract_dims[3] *= out.dimension(i); + output_dims[1] *= out.dimension(i); } } else { - pre_contract_dims[3] = kernelFilters; - pre_contract_dims[2] = inputRows * inputCols * inputPlanes; - pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[0] = 1; + output_dims[1] = kernelFilters; + output_dims[0] = outputCols * outputRows * outputPlanes; for (int i = 0; i < NumDims - 4; ++i) { - pre_contract_dims[0] *= out.dimension(i); + output_dims[0] *= out.dimension(i); } } - // The input has dimensions in_depth X (input_planes * input_rows * - // input_cols) X OTHERS - DSizes<TensorIndex, 3> input_dims; + // Reshaped extract_volume_patches(in) + DSizes<TensorIndex, 2> pre_contract_dims; 
if (isColMajor) { - input_dims[0] = kernelChannels; - input_dims[1] = inputRows * inputCols * inputPlanes; - input_dims[2] = 1; + pre_contract_dims[0] = + kernelChannels * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[1] = outputPlanes * outputRows * outputCols; for (int i = 4; i < NumDims; ++i) { - input_dims[2] *= in.dimension(i); + pre_contract_dims[1] *= in.dimension(i); } - eigen_assert(input_dims[2] == pre_contract_dims[3]); + eigen_assert(output_dims[1] == pre_contract_dims[1]); } else { - input_dims[2] = kernelChannels; - input_dims[1] = inputRows * inputCols * inputPlanes; - input_dims[0] = 1; + pre_contract_dims[1] = + kernelCols * kernelRows * kernelPlanes * kernelChannels; + pre_contract_dims[0] = outputCols * outputRows * outputPlanes; for (int i = 0; i < NumDims - 4; ++i) { - input_dims[0] *= in.dimension(i); + pre_contract_dims[0] *= in.dimension(i); } - eigen_assert(input_dims[0] == pre_contract_dims[0]); + eigen_assert(output_dims[0] == pre_contract_dims[0]); } - // We will contract along dimensions (1, 2) in and (1, 3) in out, if - // this is col-major. - // For row-major, it's dimensions (0, 1) in and (0, 2) in out. - array<IndexPair<TensorIndex>, 2> contract_dims; - if (isColMajor) { - // col-major: in.contract(output.patches) - contract_dims[0] = IndexPair<TensorIndex>(1, 1); - contract_dims[1] = IndexPair<TensorIndex>(2, 3); - } else { - // row-major: output.patches.contract(in) - contract_dims[0] = IndexPair<TensorIndex>(0, 0); - contract_dims[1] = IndexPair<TensorIndex>(2, 1); - } + array<TensorIndex, 2> shuffle_dims; + shuffle_dims[0] = 1; + shuffle_dims[1] = 0; - // After the contraction, the kernel will have dimension - // in_depth X out_depth X kernel_patches X kernel_rows X kernel_cols - // We will need to shuffle the first two dimensions and reverse the spatial - // dimensions. 
- // The end shape is: - // out_depth X in_shape X kernel_planes X kernel_rows X kernel_cols + array<IndexPair<TensorIndex>, 1> contract_dims; + contract_dims[0] = IndexPair<TensorIndex>(1, 0); - // This is the shape of the kernel *before* the shuffling. DSizes<TensorIndex, 5> kernel_dims; if (isColMajor) { - kernel_dims[0] = kernelChannels; - kernel_dims[1] = kernelFilters; + kernel_dims[0] = kernelFilters; + kernel_dims[1] = kernelChannels; kernel_dims[2] = kernelPlanes; kernel_dims[3] = kernelRows; kernel_dims[4] = kernelCols; } else { - kernel_dims[0] = kernelCols; - kernel_dims[1] = kernelRows; + kernel_dims[4] = kernelFilters; + kernel_dims[3] = kernelChannels; kernel_dims[2] = kernelPlanes; - kernel_dims[3] = kernelFilters; - kernel_dims[4] = kernelChannels; - } - - // Flip filters and channels. - array<TensorIndex, 5> kernel_shuffle; - if (isColMajor) { - kernel_shuffle[0] = 1; - kernel_shuffle[1] = 0; - kernel_shuffle[2] = 2; - kernel_shuffle[3] = 3; - kernel_shuffle[4] = 4; - } else { - kernel_shuffle[0] = 0; - kernel_shuffle[1] = 1; - kernel_shuffle[2] = 2; - kernel_shuffle[3] = 4; - kernel_shuffle[4] = 3; - } - - // Reverse the spatial dimensions. 
- array<bool, 5> kernel_reverse; - if (isColMajor) { - kernel_reverse[0] = false; - kernel_reverse[1] = false; - kernel_reverse[2] = true; - kernel_reverse[3] = true; - kernel_reverse[4] = true; - } else { - kernel_reverse[0] = true; - kernel_reverse[1] = true; - kernel_reverse[2] = true; - kernel_reverse[3] = false; - kernel_reverse[4] = false; + kernel_dims[1] = kernelRows; + kernel_dims[0] = kernelCols; } - DSizes<TensorIndex, NumDims> strides; - for (int i = 0; i < NumDims; i++) { - strides[i] = 1; - } - if (isColMajor) { - strides[1] = stridePlanes; - strides[2] = strideRows; - strides[3] = strideCols; - } else { - strides[NumDims - 2] = stridePlanes; - strides[NumDims - 3] = strideRows; - strides[NumDims - 4] = strideCols; - } return choose( Cond<internal::traits<Input>::Layout == ColMajor>(), - input.reshape(input_dims) - .contract(output_backward + output_backward.reshape(output_dims) + .contract(input .extract_volume_patches( - inputPlanes, inputRows, inputCols, 1, 1, 1, - stridePlanes, strideRows, strideCols, - - padding_ztop, padding_zbottom, padding_top, - padding_bottom, padding_left, padding_right) - .reshape(pre_contract_dims), + kernelPlanes, kernelRows, kernelCols, stridePlanes, + strideRows, strideCols, 1, 1, 1, padding_top_z, + padding_bottom_z, padding_top, padding_bottom, + padding_left, padding_right) + .reshape(pre_contract_dims) + .shuffle(shuffle_dims), contract_dims) - .reshape(kernel_dims) - .reverse(kernel_reverse) - .shuffle(kernel_shuffle), - output_backward - .extract_volume_patches(inputPlanes, inputRows, inputCols, 1, 1, 1, - stridePlanes, strideRows, strideCols, - padding_ztop, padding_zbottom, padding_top, + .reshape(kernel_dims), + input + .extract_volume_patches(kernelPlanes, kernelRows, kernelCols, + stridePlanes, strideRows, strideCols, 1, 1, 1, + padding_top_z, padding_bottom_z, padding_top, padding_bottom, padding_left, padding_right) .reshape(pre_contract_dims) - .contract(input.reshape(input_dims), contract_dims) - 
.reshape(kernel_dims) - .reverse(kernel_reverse) - .shuffle(kernel_shuffle)); + .shuffle(shuffle_dims) + .contract(output_backward.reshape(output_dims), contract_dims) + .reshape(kernel_dims)); } } // end namespace Eigen diff --git a/tensorflow/core/kernels/eigen_benchmark.h b/tensorflow/core/kernels/eigen_benchmark.h index 46ad38fb77..87e41b89b3 100644 --- a/tensorflow/core/kernels/eigen_benchmark.h +++ b/tensorflow/core/kernels/eigen_benchmark.h @@ -76,6 +76,9 @@ class SpatialConvolutionBenchmarksSuite { void SpatialConvolutionBackwardInput(Dimensions input_dims, Dimensions filter_dims) { + using OutputBackward = TTypes<float, 4>::ConstTensor; + using InputBackward = TTypes<float, 4>::Tensor; + Dimensions output_dims(input_dims[0], // batch input_dims[1], // input_height input_dims[2], // input_width @@ -85,37 +88,37 @@ class SpatialConvolutionBenchmarksSuite { Eigen::Index input_rows = input_dims[1]; Eigen::Index input_cols = input_dims[2]; - Scalar* input_data = - static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); Scalar* filter_data = static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims))); - Scalar* output_data = + Scalar* output_backward_data = static_cast<Scalar*>(device_.allocate(BufferSize(output_dims))); + Scalar* input_backward_data = + static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); - device_.memset(input_data, 123, BufferSize(input_dims)); device_.memset(filter_data, 123, BufferSize(filter_dims)); + device_.memset(output_backward_data, 123, BufferSize(output_dims)); - Input input(input_data, input_dims); Filter filter(filter_data, filter_dims); - Output output(output_data, output_dims); + OutputBackward output_backward(output_backward_data, output_dims); + InputBackward input_backward(input_backward_data, input_dims); ::tensorflow::testing::StartTiming(); for (int i = 0; i < iters_; ++i) { - output.device(device_) = Eigen::SpatialConvolutionBackwardInput( - filter, input, input_rows, input_cols); - 
tensorflow::testing::DoNotOptimize(output); + input_backward.device(device_) = Eigen::SpatialConvolutionBackwardInput( + filter, output_backward, input_rows, input_cols); + tensorflow::testing::DoNotOptimize(input_backward); } ::tensorflow::testing::StopTiming(); - device_.deallocate(input_data); device_.deallocate(filter_data); - device_.deallocate(output_data); + device_.deallocate(output_backward_data); + device_.deallocate(input_backward_data); } void SpatialConvolutionBackwardKernel(Dimensions input_dims, Dimensions filter_dims) { using OutputBackward = TTypes<float, 4>::ConstTensor; - using FilterGrad = TTypes<float, 4>::Tensor; + using FilterBackward = TTypes<float, 4>::Tensor; Dimensions output_dims(input_dims[0], // batch input_dims[1], // input_height @@ -130,7 +133,7 @@ class SpatialConvolutionBenchmarksSuite { static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); Scalar* output_backward_data = static_cast<Scalar*>(device_.allocate(BufferSize(output_dims))); - Scalar* filter_data = + Scalar* filter_backward_data = static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims))); device_.memset(input_data, 123, BufferSize(input_dims)); @@ -138,19 +141,19 @@ class SpatialConvolutionBenchmarksSuite { Input input(input_data, input_dims); OutputBackward output_backward(output_backward_data, input_dims); - FilterGrad filter_grad(filter_data, filter_dims); + FilterBackward filter_backward(filter_backward_data, filter_dims); ::tensorflow::testing::StartTiming(); for (int i = 0; i < iters_; ++i) { - filter_grad.device(device_) = Eigen::SpatialConvolutionBackwardKernel( + filter_backward.device(device_) = Eigen::SpatialConvolutionBackwardKernel( input, output_backward, filter_rows, filter_cols); - tensorflow::testing::DoNotOptimize(filter_grad); + tensorflow::testing::DoNotOptimize(filter_backward); } ::tensorflow::testing::StopTiming(); device_.deallocate(input_data); device_.deallocate(output_backward_data); - device_.deallocate(filter_data); + 
device_.deallocate(filter_backward_data); } private: @@ -215,42 +218,45 @@ class CuboidConvolutionBenchmarksSuite { input_dims[3], // input_planes filter_dims[4]); // filter_count + using OutputBackward = TTypes<float, 5>::ConstTensor; + using InputBackward = TTypes<float, 5>::Tensor; + // Assuming that the convolution had SAME padding. Eigen::Index input_rows = input_dims[1]; Eigen::Index input_cols = input_dims[2]; Eigen::Index input_planes = input_dims[3]; - Scalar* input_data = - static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); Scalar* filter_data = static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims))); - Scalar* output_data = + Scalar* output_backward_data = static_cast<Scalar*>(device_.allocate(BufferSize(output_dims))); + Scalar* input_backward_data = + static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); - device_.memset(input_data, 123, BufferSize(input_dims)); device_.memset(filter_data, 123, BufferSize(filter_dims)); + device_.memset(output_backward_data, 123, BufferSize(output_dims)); - Input input(input_data, input_dims); Filter filter(filter_data, filter_dims); - Output output(output_data, output_dims); + OutputBackward output_backward(output_backward_data, output_dims); + InputBackward input_backward(input_backward_data, input_dims); ::tensorflow::testing::StartTiming(); for (int i = 0; i < iters_; ++i) { - output.device(device_) = Eigen::CuboidConvolutionBackwardInput( - filter, input, input_planes, input_rows, input_cols); - tensorflow::testing::DoNotOptimize(output); + input_backward.device(device_) = Eigen::CuboidConvolutionBackwardInput( + filter, output_backward, input_planes, input_rows, input_cols); + tensorflow::testing::DoNotOptimize(input_backward); } ::tensorflow::testing::StopTiming(); - device_.deallocate(input_data); device_.deallocate(filter_data); - device_.deallocate(output_data); + device_.deallocate(output_backward_data); + device_.deallocate(input_backward_data); } void 
CuboidConvolutionBackwardKernel(Dimensions input_dims, Dimensions filter_dims) { using OutputBackward = TTypes<float, 5>::ConstTensor; - using FilterGrad = TTypes<float, 5>::Tensor; + using FilterBackward = TTypes<float, 5>::Tensor; Dimensions output_dims(input_dims[0], // batch input_dims[1], // input_height @@ -267,7 +273,7 @@ class CuboidConvolutionBenchmarksSuite { static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); Scalar* output_backward_data = static_cast<Scalar*>(device_.allocate(BufferSize(output_dims))); - Scalar* filter_data = + Scalar* filter_backward_data = static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims))); device_.memset(input_data, 123, BufferSize(input_dims)); @@ -275,19 +281,19 @@ class CuboidConvolutionBenchmarksSuite { Input input(input_data, input_dims); OutputBackward output_backward(output_backward_data, output_dims); - FilterGrad filter_grad(filter_data, filter_dims); + FilterBackward filter_backward(filter_backward_data, filter_dims); ::tensorflow::testing::StartTiming(); for (int i = 0; i < iters_; ++i) { - filter_grad.device(device_) = Eigen::CuboidConvolutionBackwardKernel( + filter_backward.device(device_) = Eigen::CuboidConvolutionBackwardKernel( input, output_backward, filter_planes, filter_rows, filter_cols); - tensorflow::testing::DoNotOptimize(filter_grad); + tensorflow::testing::DoNotOptimize(filter_backward); } ::tensorflow::testing::StopTiming(); device_.deallocate(input_data); device_.deallocate(output_backward_data); - device_.deallocate(filter_data); + device_.deallocate(filter_backward_data); } private: diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc index 2a8308ef9a..7c2bbb8148 100644 --- a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc +++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc @@ -123,6 +123,7 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads, #define BM_SpatialConvolution(NT, N, H, W, C, 
FC, FH, FW, LABEL) \ static void BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \ FW)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW); \ } \ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW)) @@ -130,6 +131,7 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads, #define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL) \ static void BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, \ FH, FW)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW); \ } \ BENCHMARK( \ @@ -138,6 +140,7 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads, #define BM_SpatialConvolutionBwdKernel(NT, N, H, W, C, FC, FH, FW, LABEL) \ static void BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \ FH, FW)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ SpatialConvolutionBackwardKernel(iters, NT, N, H, W, C, FC, FH, FW); \ } \ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \ @@ -348,6 +351,7 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, #define BM_CuboidConvolution(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \ static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, \ FP)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ CuboidConvolution(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \ } \ BENCHMARK( \ @@ -356,6 +360,7 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, #define BM_CuboidConvolutionBwdInput(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \ static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \ FH, FW, FP)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ CuboidConvolutionBackwardInput(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \ } \ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \ 
@@ -365,6 +370,7 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, LABEL) \ static void BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, \ FC, FH, FW, FP)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ CuboidConvolutionBackwardKernel(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \ } \ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, FC, \ @@ -395,8 +401,11 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, BM_CuboidConvolutions(8, // batch size 25, 25, 25, 4, // input: height, width, panes, depth 16, 5, 5, 5, // filter: count, height, width, panes - "conv3d"); + "conv3d_depth4"); +BM_CuboidConvolutions(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8"); -BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d"); +BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4"); +BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8"); -BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d"); +BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4"); +BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8"); diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h index c7dbefa0b4..86146f75f4 100644 --- a/tensorflow/core/kernels/gpu_utils.h +++ b/tensorflow/core/kernels/gpu_utils.h @@ -123,8 +123,7 @@ class AutoTuneMap { string GetActionSummary(StringPiece action, const Parameters& params, const Config& config) { return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(), - std::string(action).c_str(), - params.ToString().c_str(), + string(action).c_str(), params.ToString().c_str(), config.ToString().c_str()); } diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h index 066a1d603b..72581c9293 100644 --- a/tensorflow/core/kernels/list_kernels.h +++ b/tensorflow/core/kernels/list_kernels.h @@ -374,7 
+374,12 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, y->tensors.reserve(x.tensors.size()); for (const Tensor& t : x.tensors) { Tensor out_tensor; - TF_RETURN_IF_ERROR(c->allocate_temp(t.dtype(), t.shape(), &out_tensor)); + AllocatorAttributes attr; + if (t.dtype() == DT_VARIANT) { + attr.set_on_host(true); + } + TF_RETURN_IF_ERROR( + c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr)); switch (out_tensor.dtype()) { #define DTYPE_CASE(dtype) \ case DataTypeToEnum<dtype>::value: \ @@ -385,6 +390,20 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, TF_CALL_POD_TYPES(DTYPE_CASE) #undef DTYPE_CASE + + case DataTypeToEnum<Variant>::value: { + const TensorList* inner_x = t.scalar<Variant>()().get<TensorList>(); + if (inner_x == nullptr) { + return errors::InvalidArgument("Input handle is not a list. Saw: '", + t.scalar<Variant>()().DebugString(), + "'"); + } + TensorList inner_y; + TF_RETURN_IF_ERROR(TensorListZerosLike<Device>(c, *inner_x, &inner_y)); + out_tensor.scalar<Variant>()() = std::move(inner_y); + break; + } + default: return errors::InvalidArgument( "Trying to compute zeros_like for unsupported dtype ", diff --git a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc index 10e468ce46..693ed8a8f0 100644 --- a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc +++ b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc @@ -114,9 +114,7 @@ class MergeV2CheckpointsOpTest : public OpsTestBase { // Exercises "delete_old_dirs". 
for (int i = 0; i < 2; ++i) { int directory_found = - Env::Default() - ->IsDirectory(std::string(io::Dirname(prefixes[i]))) - .code(); + Env::Default()->IsDirectory(string(io::Dirname(prefixes[i]))).code(); if (delete_old_dirs) { EXPECT_EQ(error::NOT_FOUND, directory_found); } else { diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index 194a711d98..26f107f940 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -47,7 +47,7 @@ std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts( std::unordered_set<string> retval; for (const string& node_name_and_port : node_names_and_ports) { const TensorId tid = ParseTensorName(node_name_and_port); - retval.emplace(std::string(tid.first)); + retval.emplace(tid.first); } return retval; } @@ -64,7 +64,7 @@ Node* FindMutableNodeByName(const string& name, Graph* graph) { const NodeDef* FindNodeDefByName(const string& input, const GraphDef& graph_def) { const TensorId tid = ParseTensorName(input); - const string name = std::string(tid.first); + const string name = string(tid.first); for (const NodeDef& node_def : graph_def.node()) { if (node_def.name() == name) { return &node_def; @@ -423,7 +423,7 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( std::vector<DataType> data_types; std::vector<TensorShape> shapes; const TensorId tid = ParseTensorName(name_and_port); - const string node_name = std::string(tid.first); + const string node_name(tid.first); const int port = tid.second; const NodeDef* node_def = FindNodeDefByName(node_name, graph_def); CHECK_NOTNULL(node_def); @@ -522,8 +522,7 @@ RemoteFusedGraphExecuteUtils::GetTensorShapeType( const TensorShapeMap& tensor_shape_map, const string& node_name) { if (node_name.find(':') != string::npos) { const TensorId tid = ParseTensorName(node_name); - return 
GetTensorShapeType(tensor_shape_map, std::string(tid.first), - tid.second); + return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second); } else { return GetTensorShapeType(tensor_shape_map, node_name, 0); } @@ -570,7 +569,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto( const TensorId tid = ParseTensorName(name); CHECK_EQ(tensor_shape_map->count(name), 0); tensor_shape_map->emplace( - std::string(tid.first), + string(tid.first), std::make_pair(tid.second, std::make_pair(tensor.dtype(), tensor.shape()))); } @@ -692,7 +691,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( std::vector<NodeBuilder::NodeOut> node_out_list; for (const string& input : inputs) { const TensorId tid = ParseTensorName(input); - Node* node = FindMutableNodeByName(std::string(tid.first), graph); + Node* node = FindMutableNodeByName(string(tid.first), graph); CHECK_NOTNULL(node); node_out_list.emplace_back(node, tid.second); } @@ -848,7 +847,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( for (const string& subgraph_input : std::get<1>(cluster)) { const TensorId tid = ParseTensorName(subgraph_input); - const string subgraph_input_name = std::string(tid.first); + const string subgraph_input_name(tid.first); const int subgraph_input_port = tid.second; const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def); CHECK_NOTNULL(node_def); @@ -895,7 +894,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( std::deque<const Node*> queue; for (const string& output : border_outputs) { const TensorId tid = ParseTensorName(output); - const string& output_node_name = std::string(tid.first); + const string output_node_name(tid.first); for (const Node* node : graph.nodes()) { if (output_node_name == node->name()) { queue.push_back(node); @@ -975,7 +974,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( for (int j = 0; j < border_outputs.size(); ++j) { const string& 
output = border_outputs.at(j); const TensorId tid = ParseTensorName(output); - const string output_name = std::string(tid.first); + const string output_name(tid.first); Node* src_node = edge->src(); if (src_node != nullptr && src_node->name() == output_name && edge->src_output() == tid.second) { @@ -995,12 +994,11 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( // RemoteFusedGraphExecuteOpNode for (const string& output : outputs) { const TensorId output_tid = ParseTensorName(output); - const string output_name = std::string(output_tid.first); + const string output_name(output_tid.first); for (size_t i = 0; i < border_outputs.size(); ++i) { const TensorId subgraph_output_tid = ParseTensorName(border_outputs.at(i)); - const string& subgraph_output_name = - std::string(subgraph_output_tid.first); + const string subgraph_output_name(subgraph_output_tid.first); if (output_name == subgraph_output_name) { LOG(INFO) << "As graph output and subgraph output are same, " << "the graph output node is replaced by identity node"; @@ -1435,7 +1433,7 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions( GraphDef* graph_def) { const TensorId tid = ParseTensorName(input); CHECK_EQ(0, tid.second); - const string node_name = std::string(tid.first); + const string node_name(tid.first); for (NodeDef& node : *graph_def->mutable_node()) { if (node.name() != node_name) { continue; diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index e335e38bdc..82546d581a 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -161,9 +161,12 @@ void RestoreTensor(OpKernelContext* context, // If we cannot find a cached reader we will allocate our own. 
std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader; - const checkpoint::TensorSliceReader* reader = - context->slice_reader_cache()->GetReader(file_pattern, open_func, - preferred_shard); + const checkpoint::TensorSliceReader* reader = nullptr; + + if (context->slice_reader_cache()) { + reader = context->slice_reader_cache()->GetReader(file_pattern, open_func, + preferred_shard); + } if (!reader) { allocated_reader.reset(new checkpoint::TensorSliceReader( file_pattern, open_func, preferred_shard)); diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc index ab4de6c815..180eb3ca34 100644 --- a/tensorflow/core/kernels/save_restore_v2_ops.cc +++ b/tensorflow/core/kernels/save_restore_v2_ops.cc @@ -220,9 +220,9 @@ class MergeV2Checkpoints : public OpKernel { context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix)); if (delete_old_dirs_) { - const string& merged_dir = std::string(io::Dirname(merged_prefix)); + const string merged_dir(io::Dirname(merged_prefix)); for (const string& input_prefix : input_prefixes) { - const string& dirname = std::string(io::Dirname(input_prefix)); + const string dirname(io::Dirname(input_prefix)); if (dirname == merged_dir) continue; Status status = env->DeleteDir(dirname); // For sharded save, only the first delete will go through and all diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc index 2aeafa28c4..544dca96ba 100644 --- a/tensorflow/core/kernels/string_strip_op.cc +++ b/tensorflow/core/kernels/string_strip_op.cc @@ -43,7 +43,7 @@ class StringStripOp : public OpKernel { for (int64 i = 0; i < input.size(); ++i) { StringPiece entry(input(i)); str_util::RemoveWhitespaceContext(&entry); - output(i) = std::string(entry); + output(i) = string(entry); } } }; diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 632b65e9b6..2ec2651c04 100644 --- 
a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -297,7 +297,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp { resource.name()); } tensor_array_name = - std::string(StringPiece(resource.name()).substr(container.size())); + string(StringPiece(resource.name()).substr(container.size())); } auto output_handle = tensor_array_output_handle->flat<string>(); diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc index ed2bf3e8e2..1bf46b5e46 100644 --- a/tensorflow/core/kernels/whole_file_read_ops.cc +++ b/tensorflow/core/kernels/whole_file_read_ops.cc @@ -134,7 +134,7 @@ class WriteFileOp : public OpKernel { "Contents tensor must be scalar, but had shape: ", contents_input->shape().DebugString())); const string& filename = filename_input->scalar<string>()(); - const string dir = std::string(io::Dirname(filename)); + const string dir(io::Dirname(filename)); if (!context->env()->FileExists(dir).ok()) { OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir)); } diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h index 982901a39c..d5cbe6c616 100644 --- a/tensorflow/core/lib/core/errors.h +++ b/tensorflow/core/lib/core/errors.h @@ -136,11 +136,9 @@ string FormatNodeNamesForError(const T& names) { ::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s)); }); } -// TODO(b/113350742): Consolidate the two different formats `{{key value}}` and -// `^^key:value^^` in a follow-on CL. 
// LINT.IfChange inline string FormatColocationNodeForError(const string& name) { - return strings::StrCat("^^colocation_node:", name, "^^"); + return strings::StrCat("{{colocation_node ", name, "}}"); } // LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py) template <typename T> diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h index c18dc9ad1a..2d622dc229 100644 --- a/tensorflow/core/lib/gtl/inlined_vector.h +++ b/tensorflow/core/lib/gtl/inlined_vector.h @@ -13,674 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// An InlinedVector<T,N,A> is like a std::vector<T,A>, except that storage -// for sequences of length <= N are provided inline without requiring -// any heap allocation. Typically N is very small (e.g., 4) so that -// sequences that are expected to be short do not require allocations. -// -// Only some of the std::vector<> operations are currently implemented. -// Other operations may be added as needed to facilitate migrating -// code that uses std::vector<> to InlinedVector<>. -// -// NOTE: If you want an inlined version to replace use of a -// std::vector<bool>, consider using util::bitmap::InlinedBitVector<NBITS> -// in util/bitmap/inlined_bitvector.h -// -// TODO(billydonahue): change size_t to size_type where appropriate. 
- #ifndef TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ #define TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ -#include <stddef.h> -#include <stdlib.h> -#include <string.h> -#include <sys/types.h> -#include <algorithm> -#include <cstddef> -#include <iterator> -#include <memory> -#include <type_traits> -#include <vector> - -#include "tensorflow/core/lib/gtl/manual_constructor.h" -#include "tensorflow/core/platform/byte_order.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mem.h" +#include "absl/container/inlined_vector.h" +// TODO(kramerb): This is kept only because lots of targets transitively depend +// on it. Remove all targets' dependencies. +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include <initializer_list> // NOLINT(build/include_order) - namespace tensorflow { namespace gtl { -template <typename T, int N> -class InlinedVector { - public: - typedef T value_type; - typedef T* pointer; - typedef const T* const_pointer; - typedef T& reference; - typedef const T& const_reference; - typedef size_t size_type; - typedef std::ptrdiff_t difference_type; - typedef pointer iterator; - typedef const_pointer const_iterator; - - // Create an empty vector - InlinedVector(); - - // Create a vector with n copies of value_type(). - explicit InlinedVector(size_t n); - - // Create a vector with n copies of elem - InlinedVector(size_t n, const value_type& elem); - - // Create and initialize with the elements [range_start .. range_end). - // The unused enable_if argument restricts this constructor so that it is - // elided when value_type is an integral type. This prevents ambiguous - // interpretation between a call to this constructor with two integral - // arguments and a call to the preceding (n, elem) constructor. 
- template <typename InputIterator> - InlinedVector( - InputIterator range_start, InputIterator range_end, - typename std::enable_if<!std::is_integral<InputIterator>::value>::type* = - NULL) { - InitRep(); - AppendRange(range_start, range_end); - } - - InlinedVector(std::initializer_list<value_type> init) { - InitRep(); - AppendRange(init.begin(), init.end()); - } - - InlinedVector(const InlinedVector& v); - - ~InlinedVector() { clear(); } - - InlinedVector& operator=(const InlinedVector& v) { - // Optimized to avoid reallocation. - // Prefer reassignment to copy construction for elements. - const size_t s = size(); - const size_t vs = v.size(); - if (s < vs) { // grow - reserve(vs); - if (s) std::copy(v.begin(), v.begin() + s, begin()); - std::copy(v.begin() + s, v.end(), std::back_inserter(*this)); - } else { // maybe shrink - erase(begin() + vs, end()); - std::copy(v.begin(), v.end(), begin()); - } - return *this; - } - - size_t size() const { return size_internal(); } - - bool empty() const { return (size() == 0); } - - // Return number of elements that can be stored in vector - // without requiring a reallocation of underlying memory - size_t capacity() const { - if (is_inline()) { - return kFit; - } else { - return static_cast<size_t>(1) << u_.data[kSize - 2]; - } - } - - // Return a pointer to the underlying array. - // Only result[0,size()-1] are defined. 
- pointer data() { - if (is_inline()) { - return reinterpret_cast<T*>(u_.data); - } else { - return outofline_pointer(); - } - } - const_pointer data() const { - return const_cast<InlinedVector<T, N>*>(this)->data(); - } - - // Remove all elements - void clear() { - DiscardStorage(); - u_.data[kSize - 1] = 0; - } - - // Return the ith element - // REQUIRES: 0 <= i < size() - const value_type& at(size_t i) const { - DCHECK_LT(i, size()); - return data()[i]; - } - const value_type& operator[](size_t i) const { - DCHECK_LT(i, size()); - return data()[i]; - } - - // Return a non-const reference to the ith element - // REQUIRES: 0 <= i < size() - value_type& at(size_t i) { - DCHECK_LT(i, size()); - return data()[i]; - } - value_type& operator[](size_t i) { - DCHECK_LT(i, size()); - return data()[i]; - } - - value_type& back() { - DCHECK(!empty()); - return at(size() - 1); - } - - const value_type& back() const { - DCHECK(!empty()); - return at(size() - 1); - } - - value_type& front() { - DCHECK(!empty()); - return at(0); - } - - const value_type& front() const { - DCHECK(!empty()); - return at(0); - } - - // Append a T constructed with args to the vector. - // Increases size() by one. - // Amortized complexity: O(1) - // Worst-case complexity: O(size()) - template <typename... Args> - void emplace_back(Args&&... args) { - size_t s = size(); - DCHECK_LE(s, capacity()); - if (s < capacity()) { - new (data() + s) T(std::forward<Args>(args)...); - set_size_internal(s + 1); - } else { - EmplaceBackSlow(std::forward<Args>(args)...); - } - } - - // Append t to the vector. - // Increases size() by one. - // Amortized complexity: O(1) - // Worst-case complexity: O(size()) - void push_back(const value_type& t) { emplace_back(t); } - void push_back(value_type&& t) { emplace_back(std::move(t)); } - - inline void pop_back() { - DCHECK(!empty()); - const size_t s = size(); - Destroy(data() + s - 1, 1); - set_size_internal(s - 1); - } - - // Resizes the vector to contain "n" elements. 
- // If "n" is smaller than the initial size, extra elements are destroyed. - // If "n" is larger than the initial size, enough copies of "elem" - // are appended to increase the size to "n". If "elem" is omitted, - // new elements are value-initialized. - void resize(size_t n) { Resize<ValueInit>(n, nullptr); } - void resize(size_t n, const value_type& elem) { Resize<Fill>(n, &elem); } - - iterator begin() { return data(); } - const_iterator begin() const { return data(); } - - iterator end() { return data() + size(); } - const_iterator end() const { return data() + size(); } - - iterator insert(iterator pos, const value_type& v); - - iterator erase(iterator pos) { - DCHECK_LT(pos, end()); - DCHECK_GE(pos, begin()); - std::copy(pos + 1, end(), pos); - pop_back(); - return pos; - } - - iterator erase(iterator first, iterator last); - - // Enlarges the underlying representation so it can hold at least - // "n" elements without reallocation. - // Does not change size() or the actual contents of the vector. - void reserve(size_t n) { - if (n > capacity()) { - // Make room for new elements - Grow<Move>(n); - } - } - - // Swap the contents of *this with other. - // REQUIRES: value_type is swappable and copyable. - void swap(InlinedVector& other); - - private: - // Representation can either be inlined or out-of-line. - // In either case, at least sizeof(void*) + 8 bytes are available. - // - // Inlined: - // Last byte holds the length. - // First (length*sizeof(T)) bytes stores the elements. - // Outlined: - // Last byte holds kSentinel. - // Second-last byte holds lg(capacity) - // Preceding 6 bytes hold size. - // First sizeof(T*) bytes hold pointer. - - // Compute rep size. - static const size_t kSizeUnaligned = N * sizeof(T) + 1; // Room for tag - static const size_t kSize = ((kSizeUnaligned + 15) / 16) * 16; // Align - - // See how many fit T we can fit inside kSize, but no more than 254 - // since 255 is used as sentinel tag for out-of-line allocation. 
- static const unsigned int kSentinel = 255; - static const size_t kFit1 = (kSize - 1) / sizeof(T); - static const size_t kFit = (kFit1 >= kSentinel) ? (kSentinel - 1) : kFit1; - - union { - unsigned char data[kSize]; - // Force data to be aligned enough for a pointer. - T* unused_aligner; - } u_; - - inline void InitRep() { u_.data[kSize - 1] = 0; } - inline bool is_inline() const { return u_.data[kSize - 1] != kSentinel; } - - inline T* outofline_pointer() const { - T* ptr; - memcpy(&ptr, &u_.data[0], sizeof(ptr)); - return ptr; - } - - inline void set_outofline_pointer(T* p) { - memcpy(&u_.data[0], &p, sizeof(p)); - } - - inline uint64_t outofline_word() const { - uint64_t word; - memcpy(&word, &u_.data[kSize - 8], sizeof(word)); - return word; - } - - inline void set_outofline_word(uint64_t w) { - memcpy(&u_.data[kSize - 8], &w, sizeof(w)); - } - - inline size_t size_internal() const { - uint8_t s = static_cast<uint8_t>(u_.data[kSize - 1]); - if (s != kSentinel) { - return static_cast<size_t>(s); - } else { - const uint64_t word = outofline_word(); - if (port::kLittleEndian) { - // The sentinel and capacity bits are most-significant bits in word. - return static_cast<size_t>(word & 0xffffffffffffull); - } else { - // The sentinel and capacity bits are least-significant bits in word. - return static_cast<size_t>(word >> 16); - } - } - } - - void set_size_internal(size_t n) { - if (is_inline()) { - DCHECK_LT(n, kSentinel); - u_.data[kSize - 1] = static_cast<unsigned char>(n); - } else { - uint64_t word; - if (port::kLittleEndian) { - // The sentinel and capacity bits are most-significant bits in word. - word = (static_cast<uint64_t>(n) | - (static_cast<uint64_t>(u_.data[kSize - 2]) << 48) | - (static_cast<uint64_t>(kSentinel) << 56)); - } else { - // The sentinel and capacity bits are least-significant bits in word. 
- word = ((static_cast<uint64_t>(n) << 16) | - (static_cast<uint64_t>(u_.data[kSize - 2]) << 8) | - (static_cast<uint64_t>(kSentinel))); - } - set_outofline_word(word); - DCHECK_EQ(u_.data[kSize - 1], kSentinel) << n; - } - } - - void DiscardStorage() { - T* base = data(); - size_t n = size(); - Destroy(base, n); - if (!is_inline()) { - port::Free(base); - } - } - - template <typename... Args> - void EmplaceBackSlow(Args&&... args) { - const size_t s = size(); - DCHECK_EQ(s, capacity()); - Grow<Move, Construct>(s + 1, std::forward<Args>(args)...); - set_size_internal(s + 1); - } - - // Movers for Grow - // Does nothing. - static void Nop(T* src, size_t n, T* dst) {} - - // Moves srcs[0,n-1] contents to dst[0,n-1]. - static void Move(T* src, size_t n, T* dst) { - for (size_t i = 0; i < n; i++) { - new (dst + i) T(std::move(*(src + i))); - } - } - - // Initializers for Resize. - // Initializes dst[0,n-1] with empty constructor. - static void ValueInit(const T*, size_t n, T* dst) { - for (size_t i = 0; i < n; i++) { - new (dst + i) T(); - } - } - - // Initializes dst[0,n-1] with copies of *src. - static void Fill(const T* src, size_t n, T* dst) { - for (size_t i = 0; i < n; i++) { - new (dst + i) T(*src); - } - } - - void Destroy(T* src, int n) { - if (!std::is_trivially_destructible<T>::value) { - for (int i = 0; i < n; i++) { - (src + i)->~T(); - } - } - } - - // Initialization methods for Grow. - // 1) Leave uninitialized memory. - struct Uninitialized { - void operator()(T*) const {} - }; - // 2) Construct a T with args at not-yet-initialized memory pointed by dst. - struct Construct { - template <class... Args> - void operator()(T* dst, Args&&... args) const { - new (dst) T(std::forward<Args>(args)...); - } - }; - - // Grow so that capacity >= n. Uses Mover to move existing elements - // to new buffer, and possibly initialize the new element according - // to InitType. 
- // We pass the InitType and Mover as template arguments so that - // this code compiles even if T does not support copying or default - // construction. - template <void(Mover)(T*, size_t, T*), class InitType = Uninitialized, - class... Args> - void Grow(size_t n, Args&&... args) { - size_t s = size(); - DCHECK_LE(s, capacity()); - - // Compute new capacity by repeatedly doubling current capacity - size_t target = 1; - size_t target_lg = 0; - while (target < kFit || target < n) { - // TODO(psrc): Check and avoid overflow? - target_lg++; - target <<= 1; - } - - T* src = data(); - T* dst = static_cast<T*>(port::Malloc(target * sizeof(T))); - - // Need to copy elem before discarding src since it might alias src. - InitType{}(dst + s, std::forward<Args>(args)...); - Mover(src, s, dst); - DiscardStorage(); - - u_.data[kSize - 1] = kSentinel; - u_.data[kSize - 2] = static_cast<unsigned char>(target_lg); - set_size_internal(s); - DCHECK_EQ(capacity(), target); - set_outofline_pointer(dst); - } - - // Resize to size n. Any new elements are initialized by passing - // elem and the destination to Initializer. We pass the Initializer - // as a template argument so that this code compiles even if T does - // not support copying. - template <void(Initializer)(const T*, size_t, T*)> - void Resize(size_t n, const T* elem) { - size_t s = size(); - if (n <= s) { - Destroy(data() + n, s - n); - set_size_internal(n); - return; - } - reserve(n); - DCHECK_GE(capacity(), n); - set_size_internal(n); - Initializer(elem, n - s, data() + s); - } - - template <typename Iter> - void AppendRange(Iter first, Iter last, std::input_iterator_tag); - - // Faster path for forward iterators. - template <typename Iter> - void AppendRange(Iter first, Iter last, std::forward_iterator_tag); - - template <typename Iter> - void AppendRange(Iter first, Iter last); -}; - -// Provide linkage for constants. 
-template <typename T, int N> -const size_t InlinedVector<T, N>::kSizeUnaligned; -template <typename T, int N> -const size_t InlinedVector<T, N>::kSize; -template <typename T, int N> -const unsigned int InlinedVector<T, N>::kSentinel; -template <typename T, int N> -const size_t InlinedVector<T, N>::kFit1; -template <typename T, int N> -const size_t InlinedVector<T, N>::kFit; - -template <typename T, int N> -inline void swap(InlinedVector<T, N>& a, InlinedVector<T, N>& b) { - a.swap(b); -} - -template <typename T, int N> -inline bool operator==(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin()); -} - -template <typename T, int N> -inline bool operator!=(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return !(a == b); -} - -template <typename T, int N> -inline bool operator<(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end()); -} - -template <typename T, int N> -inline bool operator>(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return b < a; -} - -template <typename T, int N> -inline bool operator<=(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return !(b < a); -} - -template <typename T, int N> -inline bool operator>=(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return !(a < b); -} - -// ======================================== -// Implementation - -template <typename T, int N> -inline InlinedVector<T, N>::InlinedVector() { - InitRep(); -} - -template <typename T, int N> -inline InlinedVector<T, N>::InlinedVector(size_t n) { - InitRep(); - if (n > capacity()) { - Grow<Nop>(n); // Must use Nop in case T is not copyable - } - set_size_internal(n); - ValueInit(nullptr, n, data()); -} - -template <typename T, int N> -inline InlinedVector<T, N>::InlinedVector(size_t n, const value_type& elem) { - InitRep(); - if (n > 
capacity()) { - Grow<Nop>(n); // Can use Nop since we know we have nothing to copy - } - set_size_internal(n); - Fill(&elem, n, data()); -} - -template <typename T, int N> -inline InlinedVector<T, N>::InlinedVector(const InlinedVector& v) { - InitRep(); - *this = v; -} - -template <typename T, int N> -typename InlinedVector<T, N>::iterator InlinedVector<T, N>::insert( - iterator pos, const value_type& v) { - DCHECK_GE(pos, begin()); - DCHECK_LE(pos, end()); - if (pos == end()) { - push_back(v); - return end() - 1; - } - size_t s = size(); - size_t idx = std::distance(begin(), pos); - if (s == capacity()) { - Grow<Move>(s + 1); - } - CHECK_LT(s, capacity()); - pos = begin() + idx; // Reset 'pos' into a post-enlarge iterator. - Fill(data() + s - 1, 1, data() + s); // data[s] = data[s-1] - std::copy_backward(pos, data() + s - 1, data() + s); - *pos = v; - - set_size_internal(s + 1); - return pos; -} - -template <typename T, int N> -typename InlinedVector<T, N>::iterator InlinedVector<T, N>::erase( - iterator first, iterator last) { - DCHECK_LE(begin(), first); - DCHECK_LE(first, last); - DCHECK_LE(last, end()); - - size_t s = size(); - ptrdiff_t erase_gap = std::distance(first, last); - std::copy(last, data() + s, first); - Destroy(data() + s - erase_gap, erase_gap); - set_size_internal(s - erase_gap); - return first; -} - -template <typename T, int N> -void InlinedVector<T, N>::swap(InlinedVector& other) { - using std::swap; // Augment ADL with std::swap. - if (&other == this) { - return; - } - - InlinedVector* a = this; - InlinedVector* b = &other; - - const bool a_inline = a->is_inline(); - const bool b_inline = b->is_inline(); - - if (!a_inline && !b_inline) { - // Just swap the top-level representations. 
- T* aptr = a->outofline_pointer(); - T* bptr = b->outofline_pointer(); - a->set_outofline_pointer(bptr); - b->set_outofline_pointer(aptr); - - uint64_t aword = a->outofline_word(); - uint64_t bword = b->outofline_word(); - a->set_outofline_word(bword); - b->set_outofline_word(aword); - return; - } - - // Make a the larger of the two to reduce number of cases. - size_t a_size = a->size(); - size_t b_size = b->size(); - if (a->size() < b->size()) { - swap(a, b); - swap(a_size, b_size); - } - DCHECK_GE(a_size, b_size); - - if (b->capacity() < a_size) { - b->Grow<Move>(a_size); - } - - // One is inline and one is not. - // 'a' is larger. Swap the elements up to the smaller array size. - std::swap_ranges(a->data(), a->data() + b_size, b->data()); - std::uninitialized_copy(a->data() + b_size, a->data() + a_size, - b->data() + b_size); - Destroy(a->data() + b_size, a_size - b_size); - a->set_size_internal(b_size); - b->set_size_internal(a_size); - DCHECK_EQ(b->size(), a_size); - DCHECK_EQ(a->size(), b_size); -} - -template <typename T, int N> -template <typename Iter> -inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last, - std::input_iterator_tag) { - std::copy(first, last, std::back_inserter(*this)); -} - -template <typename T, int N> -template <typename Iter> -inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last, - std::forward_iterator_tag) { - typedef typename std::iterator_traits<Iter>::difference_type Length; - Length length = std::distance(first, last); - size_t s = size(); - reserve(s + length); - std::uninitialized_copy_n(first, length, data() + s); - set_size_internal(s + length); -} - -template <typename T, int N> -template <typename Iter> -inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last) { - typedef typename std::iterator_traits<Iter>::iterator_category IterTag; - AppendRange(first, last, IterTag()); -} +using absl::InlinedVector; } // namespace gtl } // namespace tensorflow diff --git 
a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc deleted file mode 100644 index 2721885c4a..0000000000 --- a/tensorflow/core/lib/gtl/inlined_vector_test.cc +++ /dev/null @@ -1,898 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/lib/gtl/inlined_vector.h" - -#include <list> -#include <memory> -#include <string> -#include <vector> - -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -typedef tensorflow::gtl::InlinedVector<int, 8> IntVec; - -// A type that counts number of live occurrences of the type -static int64 instances = 0; -class Instance { - public: - int value_; - explicit Instance(int x) : value_(x) { instances++; } - Instance(const Instance& x) : value_(x.value_) { instances++; } - ~Instance() { instances--; } - - friend inline void swap(Instance& a, Instance& b) { - using std::swap; - swap(a.value_, b.value_); - } - - friend std::ostream& operator<<(std::ostream& o, const Instance& v) { - return o << "[value:" << v.value_ << "]"; - } -}; - -typedef tensorflow::gtl::InlinedVector<Instance, 8> InstanceVec; - -// A simple reference counted class 
to make sure that the proper elements are -// destroyed in the erase(begin, end) test. -class RefCounted { - public: - RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); } - - RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) { - VLOG(5) << "[RefCounted: copy" - << " from count @" << v.count_ << "]"; - Ref(); - } - - ~RefCounted() { - Unref(); - count_ = nullptr; - } - - friend void swap(RefCounted& a, RefCounted& b) { - using std::swap; - swap(a.value_, b.value_); - swap(a.count_, b.count_); - } - - RefCounted& operator=(RefCounted v) { - using std::swap; - swap(*this, v); - return *this; - } - - void Ref() const { - CHECK(count_ != nullptr); - ++(*count_); - VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]"; - } - - void Unref() const { - --(*count_); - CHECK_GE(*count_, 0); - VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]"; - } - - int count() const { return *count_; } - - friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) { - return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]"; - } - - int value_; - int* count_; -}; - -typedef tensorflow::gtl::InlinedVector<RefCounted, 8> RefCountedVec; - -// A class with a vtable pointer -class Dynamic { - public: - virtual ~Dynamic() {} - - friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) { - return o << "[Dynamic]"; - } -}; - -typedef tensorflow::gtl::InlinedVector<Dynamic, 8> DynamicVec; - -// Append 0..len-1 to *v -static void Fill(IntVec* v, int len, int offset = 0) { - for (int i = 0; i < len; i++) { - v->push_back(i + offset); - } -} - -static IntVec Fill(int len, int offset = 0) { - IntVec v; - Fill(&v, len, offset); - return v; -} - -TEST(IntVec, SimpleOps) { - for (int len = 0; len < 20; len++) { - IntVec v; - const IntVec& cv = v; // const alias - - Fill(&v, len); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - - for (int i = 0; i < len; i++) { - 
EXPECT_EQ(i, v[i]); - } - EXPECT_EQ(v.begin(), v.data()); - EXPECT_EQ(cv.begin(), cv.data()); - - int counter = 0; - for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) { - EXPECT_EQ(counter, *iter); - counter++; - } - EXPECT_EQ(counter, len); - - counter = 0; - for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) { - EXPECT_EQ(counter, *iter); - counter++; - } - EXPECT_EQ(counter, len); - - if (len > 0) { - EXPECT_EQ(0, v.front()); - EXPECT_EQ(len - 1, v.back()); - v.pop_back(); - EXPECT_EQ(len - 1, v.size()); - for (size_t i = 0; i < v.size(); ++i) { - EXPECT_EQ(i, v[i]); - } - } - } -} - -TEST(IntVec, Erase) { - for (int len = 1; len < 20; len++) { - for (int i = 0; i < len; ++i) { - IntVec v; - Fill(&v, len); - v.erase(v.begin() + i); - EXPECT_EQ(len - 1, v.size()); - for (int j = 0; j < i; ++j) { - EXPECT_EQ(j, v[j]); - } - for (int j = i; j < len - 1; ++j) { - EXPECT_EQ(j + 1, v[j]); - } - } - } -} - -// At the end of this test loop, the elements between [erase_begin, erase_end) -// should have reference counts == 0, and all others elements should have -// reference counts == 1. -TEST(RefCountedVec, EraseBeginEnd) { - for (int len = 1; len < 20; ++len) { - for (int erase_begin = 0; erase_begin < len; ++erase_begin) { - for (int erase_end = erase_begin; erase_end <= len; ++erase_end) { - std::vector<int> counts(len, 0); - RefCountedVec v; - for (int i = 0; i < len; ++i) { - v.push_back(RefCounted(i, &counts[i])); - } - - int erase_len = erase_end - erase_begin; - - v.erase(v.begin() + erase_begin, v.begin() + erase_end); - - EXPECT_EQ(len - erase_len, v.size()); - - // Check the elements before the first element erased. - for (int i = 0; i < erase_begin; ++i) { - EXPECT_EQ(i, v[i].value_); - } - - // Check the elements after the first element erased. - for (size_t i = erase_begin; i < v.size(); ++i) { - EXPECT_EQ(i + erase_len, v[i].value_); - } - - // Check that the elements at the beginning are preserved. 
- for (int i = 0; i < erase_begin; ++i) { - EXPECT_EQ(1, counts[i]); - } - - // Check that the erased elements are destroyed - for (int i = erase_begin; i < erase_end; ++i) { - EXPECT_EQ(0, counts[i]); - } - - // Check that the elements at the end are preserved. - for (int i = erase_end; i < len; ++i) { - EXPECT_EQ(1, counts[i]); - } - } - } - } -} - -struct NoDefaultCtor { - explicit NoDefaultCtor(int) {} -}; -struct NoCopy { - NoCopy() {} - NoCopy(const NoCopy&) = delete; -}; -struct NoAssign { - NoAssign() {} - NoAssign& operator=(const NoAssign&) = delete; -}; -struct MoveOnly { - MoveOnly() {} - MoveOnly(MoveOnly&&) = default; - MoveOnly& operator=(MoveOnly&&) = default; -}; -TEST(InlinedVectorTest, NoDefaultCtor) { - tensorflow::gtl::InlinedVector<NoDefaultCtor, 1> v(10, NoDefaultCtor(2)); - (void)v; -} -TEST(InlinedVectorTest, NoCopy) { - tensorflow::gtl::InlinedVector<NoCopy, 1> v(10); - (void)v; -} -TEST(InlinedVectorTest, NoAssign) { - tensorflow::gtl::InlinedVector<NoAssign, 1> v(10); - (void)v; -} -TEST(InlinedVectorTest, MoveOnly) { - gtl::InlinedVector<MoveOnly, 2> v; - v.push_back(MoveOnly{}); - v.push_back(MoveOnly{}); - v.push_back(MoveOnly{}); -} - -TEST(IntVec, Insert) { - for (int len = 0; len < 20; len++) { - for (int pos = 0; pos <= len; pos++) { - IntVec v; - Fill(&v, len); - v.insert(v.begin() + pos, 9999); - EXPECT_EQ(v.size(), len + 1); - for (int i = 0; i < pos; i++) { - EXPECT_EQ(v[i], i); - } - EXPECT_EQ(v[pos], 9999); - for (size_t i = pos + 1; i < v.size(); i++) { - EXPECT_EQ(v[i], i - 1); - } - } - } -} - -TEST(RefCountedVec, InsertConstructorDestructor) { - // Make sure the proper construction/destruction happen during insert - // operations. 
- for (int len = 0; len < 20; len++) { - SCOPED_TRACE(len); - for (int pos = 0; pos <= len; pos++) { - SCOPED_TRACE(pos); - std::vector<int> counts(len, 0); - int inserted_count = 0; - RefCountedVec v; - for (int i = 0; i < len; ++i) { - SCOPED_TRACE(i); - v.push_back(RefCounted(i, &counts[i])); - } - - for (auto elem : counts) { - EXPECT_EQ(1, elem); - } - - RefCounted insert_element(9999, &inserted_count); - EXPECT_EQ(1, inserted_count); - v.insert(v.begin() + pos, insert_element); - EXPECT_EQ(2, inserted_count); - // Check that the elements at the end are preserved. - for (auto elem : counts) { - EXPECT_EQ(1, elem); - } - EXPECT_EQ(2, inserted_count); - } - } -} - -TEST(IntVec, Resize) { - for (int len = 0; len < 20; len++) { - IntVec v; - Fill(&v, len); - - // Try resizing up and down by k elements - static const int kResizeElem = 1000000; - for (int k = 0; k < 10; k++) { - // Enlarging resize - v.resize(len + k, kResizeElem); - EXPECT_EQ(len + k, v.size()); - EXPECT_LE(len + k, v.capacity()); - for (int i = 0; i < len + k; i++) { - if (i < len) { - EXPECT_EQ(i, v[i]); - } else { - EXPECT_EQ(kResizeElem, v[i]); - } - } - - // Shrinking resize - v.resize(len, kResizeElem); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - for (int i = 0; i < len; i++) { - EXPECT_EQ(i, v[i]); - } - } - } -} - -TEST(IntVec, InitWithLength) { - for (int len = 0; len < 20; len++) { - IntVec v(len, 7); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - for (int i = 0; i < len; i++) { - EXPECT_EQ(7, v[i]); - } - } -} - -TEST(IntVec, CopyConstructorAndAssignment) { - for (int len = 0; len < 20; len++) { - IntVec v; - Fill(&v, len); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - - IntVec v2(v); - EXPECT_EQ(v, v2); - - for (int start_len = 0; start_len < 20; start_len++) { - IntVec v3; - Fill(&v3, start_len, 99); // Add dummy elements that should go away - v3 = v; - EXPECT_EQ(v, v3); - } - } -} - -TEST(OverheadTest, Storage) { - // Check for size 
overhead. - using tensorflow::gtl::InlinedVector; - EXPECT_EQ(2 * sizeof(int*), sizeof(InlinedVector<int*, 1>)); - EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 2>)); - EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 3>)); - EXPECT_EQ(6 * sizeof(int*), sizeof(InlinedVector<int*, 4>)); - - EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 1>)); - EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 2>)); - EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 3>)); - EXPECT_EQ(2 * sizeof(char*), - sizeof(InlinedVector<char, 2 * sizeof(char*) - 1>)); - EXPECT_EQ(4 * sizeof(char*), sizeof(InlinedVector<char, 2 * sizeof(char*)>)); -} - -TEST(IntVec, Clear) { - for (int len = 0; len < 20; len++) { - SCOPED_TRACE(len); - IntVec v; - Fill(&v, len); - v.clear(); - EXPECT_EQ(0, v.size()); - EXPECT_EQ(v.begin(), v.end()); - } -} - -TEST(IntVec, Reserve) { - for (size_t len = 0; len < 20; len++) { - IntVec v; - Fill(&v, len); - - for (size_t newlen = 0; newlen < 100; newlen++) { - const int* start_rep = v.data(); - v.reserve(newlen); - const int* final_rep = v.data(); - if (newlen <= len) { - EXPECT_EQ(start_rep, final_rep); - } - EXPECT_LE(newlen, v.capacity()); - - // Filling up to newlen should not change rep - while (v.size() < newlen) { - v.push_back(0); - } - EXPECT_EQ(final_rep, v.data()); - } - } -} - -template <typename T> -static std::vector<typename T::value_type> Vec(const T& src) { - std::vector<typename T::value_type> result; - for (const auto& elem : src) { - result.push_back(elem); - } - return result; -} - -TEST(IntVec, SelfRefPushBack) { - std::vector<string> std_v; - tensorflow::gtl::InlinedVector<string, 4> v; - const string s = "A quite long string to ensure heap."; - std_v.push_back(s); - v.push_back(s); - for (int i = 0; i < 20; ++i) { - EXPECT_EQ(std_v, Vec(v)); - - v.push_back(v.back()); - std_v.push_back(std_v.back()); - } - EXPECT_EQ(std_v, Vec(v)); -} - -TEST(IntVec, SelfRefPushBackWithMove) { - std::vector<string> std_v; 
- gtl::InlinedVector<string, 4> v; - const string s = "A quite long string to ensure heap."; - std_v.push_back(s); - v.push_back(s); - for (int i = 0; i < 20; ++i) { - EXPECT_EQ(v.back(), std_v.back()); - - v.push_back(std::move(v.back())); - std_v.push_back(std::move(std_v.back())); - } - EXPECT_EQ(v.back(), std_v.back()); -} - -TEST(IntVec, Swap) { - for (int l1 = 0; l1 < 20; l1++) { - SCOPED_TRACE(l1); - for (int l2 = 0; l2 < 20; l2++) { - SCOPED_TRACE(l2); - IntVec a = Fill(l1, 0); - IntVec b = Fill(l2, 100); - { - using std::swap; - swap(a, b); - } - EXPECT_EQ(l1, b.size()); - EXPECT_EQ(l2, a.size()); - for (int i = 0; i < l1; i++) { - SCOPED_TRACE(i); - EXPECT_EQ(i, b[i]); - } - for (int i = 0; i < l2; i++) { - SCOPED_TRACE(i); - EXPECT_EQ(100 + i, a[i]); - } - } - } -} - -TEST(InstanceVec, Swap) { - for (int l1 = 0; l1 < 20; l1++) { - for (int l2 = 0; l2 < 20; l2++) { - InstanceVec a, b; - for (int i = 0; i < l1; i++) a.push_back(Instance(i)); - for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i)); - EXPECT_EQ(l1 + l2, instances); - { - using std::swap; - swap(a, b); - } - EXPECT_EQ(l1 + l2, instances); - EXPECT_EQ(l1, b.size()); - EXPECT_EQ(l2, a.size()); - for (int i = 0; i < l1; i++) { - EXPECT_EQ(i, b[i].value_); - } - for (int i = 0; i < l2; i++) { - EXPECT_EQ(100 + i, a[i].value_); - } - } - } -} - -TEST(IntVec, EqualAndNotEqual) { - IntVec a, b; - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - - a.push_back(3); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - b.push_back(3); - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - - b.push_back(7); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - a.push_back(6); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - a.clear(); - b.clear(); - for (int i = 0; i < 100; i++) { - a.push_back(i); - b.push_back(i); - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - - b[i] = b[i] + 1; - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - b[i] = b[i] - 1; // Back to before - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != 
b); - } -} - -TEST(IntVec, RelationalOps) { - IntVec a, b; - EXPECT_FALSE(a < b); - EXPECT_FALSE(b < a); - EXPECT_FALSE(a > b); - EXPECT_FALSE(b > a); - EXPECT_TRUE(a <= b); - EXPECT_TRUE(b <= a); - EXPECT_TRUE(a >= b); - EXPECT_TRUE(b >= a); - b.push_back(3); - EXPECT_TRUE(a < b); - EXPECT_FALSE(b < a); - EXPECT_FALSE(a > b); - EXPECT_TRUE(b > a); - EXPECT_TRUE(a <= b); - EXPECT_FALSE(b <= a); - EXPECT_FALSE(a >= b); - EXPECT_TRUE(b >= a); -} - -TEST(InstanceVec, CountConstructorsDestructors) { - const int start = instances; - for (int len = 0; len < 20; len++) { - InstanceVec v; - for (int i = 0; i < len; i++) { - v.push_back(Instance(i)); - } - EXPECT_EQ(start + len, instances); - - { // Copy constructor should create 'len' more instances. - InstanceVec v_copy(v); - EXPECT_EQ(start + len + len, instances); - } - EXPECT_EQ(start + len, instances); - - // Enlarging resize() must construct some objects - v.resize(len + 10, Instance(100)); - EXPECT_EQ(start + len + 10, instances); - - // Shrinking resize() must destroy some objects - v.resize(len, Instance(100)); - EXPECT_EQ(start + len, instances); - - // reserve() must not increase the number of initialized objects - v.reserve(len + 1000); - EXPECT_EQ(start + len, instances); - - // pop_back() and erase() must destroy one object - if (len > 0) { - v.pop_back(); - EXPECT_EQ(start + len - 1, instances); - if (!v.empty()) { - v.erase(v.begin()); - EXPECT_EQ(start + len - 2, instances); - } - } - } - EXPECT_EQ(start, instances); -} - -TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) { - const int start = instances; - for (int len = 0; len < 20; len++) { - for (int longorshort = 0; longorshort <= 1; ++longorshort) { - InstanceVec longer, shorter; - for (int i = 0; i < len; i++) { - longer.push_back(Instance(i)); - shorter.push_back(Instance(i)); - } - longer.push_back(Instance(len)); - EXPECT_EQ(start + len + len + 1, instances); - - if (longorshort) { - shorter = longer; - EXPECT_EQ(start + (len + 1) + (len 
+ 1), instances); - } else { - longer = shorter; - EXPECT_EQ(start + len + len, instances); - } - } - } - EXPECT_EQ(start, instances); -} - -TEST(RangedConstructor, SimpleType) { - std::vector<int> source_v = {4, 5, 6, 7}; - // First try to fit in inline backing - tensorflow::gtl::InlinedVector<int, 4> v(source_v.begin(), source_v.end()); - tensorflow::gtl::InlinedVector<int, 4> empty4; - EXPECT_EQ(4, v.size()); - EXPECT_EQ(empty4.capacity(), v.capacity()); // Must still be inline - EXPECT_EQ(4, v[0]); - EXPECT_EQ(5, v[1]); - EXPECT_EQ(6, v[2]); - EXPECT_EQ(7, v[3]); - - // Now, force a re-allocate - tensorflow::gtl::InlinedVector<int, 2> realloc_v(source_v.begin(), - source_v.end()); - tensorflow::gtl::InlinedVector<int, 2> empty2; - EXPECT_EQ(4, realloc_v.size()); - EXPECT_LT(empty2.capacity(), realloc_v.capacity()); - EXPECT_EQ(4, realloc_v[0]); - EXPECT_EQ(5, realloc_v[1]); - EXPECT_EQ(6, realloc_v[2]); - EXPECT_EQ(7, realloc_v[3]); -} - -TEST(RangedConstructor, ComplexType) { - // We also use a list here to pass a different flavor of iterator (e.g. not - // random-access). 
- std::list<Instance> source_v = {Instance(0)}; - - // First try to fit in inline backing - tensorflow::gtl::InlinedVector<Instance, 1> v(source_v.begin(), - source_v.end()); - tensorflow::gtl::InlinedVector<Instance, 1> empty1; - EXPECT_EQ(1, v.size()); - EXPECT_EQ(empty1.capacity(), v.capacity()); // Must still be inline - EXPECT_EQ(0, v[0].value_); - - std::list<Instance> source_v2 = {Instance(0), Instance(1), Instance(2), - Instance(3)}; - // Now, force a re-allocate - tensorflow::gtl::InlinedVector<Instance, 1> realloc_v(source_v2.begin(), - source_v2.end()); - EXPECT_EQ(4, realloc_v.size()); - EXPECT_LT(empty1.capacity(), realloc_v.capacity()); - EXPECT_EQ(0, realloc_v[0].value_); - EXPECT_EQ(1, realloc_v[1].value_); - EXPECT_EQ(2, realloc_v[2].value_); - EXPECT_EQ(3, realloc_v[3].value_); -} - -TEST(RangedConstructor, ElementsAreConstructed) { - std::vector<string> source_v = {"cat", "dog"}; - - // Force expansion and re-allocation of v. Ensures that when the vector is - // expanded that new elements are constructed. 
- tensorflow::gtl::InlinedVector<string, 1> v(source_v.begin(), source_v.end()); - EXPECT_EQ("cat", v[0]); - EXPECT_EQ("dog", v[1]); -} - -TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) { - auto vec = tensorflow::gtl::InlinedVector<int, 3>{4, 5, 6}; - EXPECT_EQ(3, vec.size()); - EXPECT_EQ(3, vec.capacity()); - EXPECT_EQ(4, vec[0]); - EXPECT_EQ(5, vec[1]); - EXPECT_EQ(6, vec[2]); -} - -TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) { - auto vec = tensorflow::gtl::InlinedVector<int, 2>{4, 5, 6}; - EXPECT_EQ(3, vec.size()); - EXPECT_LE(3, vec.capacity()); - EXPECT_EQ(4, vec[0]); - EXPECT_EQ(5, vec[1]); - EXPECT_EQ(6, vec[2]); -} - -TEST(InitializerListConstructor, DisparateTypesInList) { - EXPECT_EQ((std::vector<int>{-7, 8}), - Vec(tensorflow::gtl::InlinedVector<int, 2>{-7, 8ULL})); - - EXPECT_EQ( - (std::vector<string>{"foo", "bar"}), - Vec(tensorflow::gtl::InlinedVector<string, 2>{"foo", string("bar")})); -} - -TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) { - tensorflow::gtl::InlinedVector<Instance, 1> empty; - auto vec = tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0)}; - EXPECT_EQ(1, vec.size()); - EXPECT_EQ(empty.capacity(), vec.capacity()); - EXPECT_EQ(0, vec[0].value_); -} - -TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) { - auto vec = - tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0), Instance(1)}; - EXPECT_EQ(2, vec.size()); - EXPECT_LE(2, vec.capacity()); - EXPECT_EQ(0, vec[0].value_); - EXPECT_EQ(1, vec[1].value_); -} - -TEST(DynamicVec, DynamicVecCompiles) { - DynamicVec v; - (void)v; -} - -static void BM_InlinedVectorFill(int iters, int len) { - for (int i = 0; i < iters; i++) { - IntVec v; - for (int j = 0; j < len; j++) { - v.push_back(j); - } - } - testing::BytesProcessed((int64{iters} * len) * sizeof(int)); -} -BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024); - -static void BM_InlinedVectorFillRange(int iters, int len) { - std::unique_ptr<int[]> ia(new 
int[len]); - for (int j = 0; j < len; j++) { - ia[j] = j; - } - for (int i = 0; i < iters; i++) { - IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len); - } - testing::BytesProcessed((int64{iters} * len) * sizeof(int)); -} -BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024); - -static void BM_StdVectorFill(int iters, int len) { - for (int i = 0; i < iters; i++) { - std::vector<int> v; - v.reserve(len); - for (int j = 0; j < len; j++) { - v.push_back(j); - } - } - testing::BytesProcessed((int64{iters} * len) * sizeof(int)); -} -BENCHMARK(BM_StdVectorFill)->Range(0, 1024); - -bool StringRepresentedInline(string s) { - const char* chars = s.data(); - string s1 = std::move(s); - return s1.data() != chars; -} - -static void BM_InlinedVectorFillString(int iters, int len) { - string strings[4] = {"a quite long string", "another long string", - "012345678901234567", "to cause allocation"}; - for (int i = 0; i < iters; i++) { - gtl::InlinedVector<string, 8> v; - for (int j = 0; j < len; j++) { - v.push_back(strings[j & 3]); - } - } - testing::ItemsProcessed(int64{iters} * len); -} -BENCHMARK(BM_InlinedVectorFillString)->Range(0, 1024); - -static void BM_StdVectorFillString(int iters, int len) { - string strings[4] = {"a quite long string", "another long string", - "012345678901234567", "to cause allocation"}; - for (int i = 0; i < iters; i++) { - std::vector<string> v; - v.reserve(len); - for (int j = 0; j < len; j++) { - v.push_back(strings[j & 3]); - } - } - testing::ItemsProcessed(int64{iters} * len); - // The purpose of the benchmark is to verify that inlined vector is - // efficient when moving is more efficient than copying. To do so, we - // use strings that are larger than the small string optimization. - CHECK(!StringRepresentedInline(strings[0])); -} -BENCHMARK(BM_StdVectorFillString)->Range(0, 1024); - -namespace { -struct Buffer { // some arbitrary structure for benchmarking. 
- char* base; - int length; - int capacity; - void* user_data; -}; -} // anonymous namespace - -static void BM_InlinedVectorTenAssignments(int iters, int len) { - typedef tensorflow::gtl::InlinedVector<Buffer, 2> BufferVec; - - BufferVec src; - src.resize(len); - - iters *= 10; - BufferVec dst; - for (int i = 0; i < iters; i++) { - dst = src; - } -} -BENCHMARK(BM_InlinedVectorTenAssignments) - ->Arg(0) - ->Arg(1) - ->Arg(2) - ->Arg(3) - ->Arg(4) - ->Arg(20); - -static void BM_CreateFromInitializerList(int iters) { - for (; iters > 0; iters--) { - tensorflow::gtl::InlinedVector<int, 4> x{1, 2, 3}; - (void)x[0]; - } -} -BENCHMARK(BM_CreateFromInitializerList); - -namespace { - -struct LargeSwappable { - LargeSwappable() : d_(1024, 17) {} - ~LargeSwappable() {} - LargeSwappable(const LargeSwappable& o) : d_(o.d_) {} - - friend void swap(LargeSwappable& a, LargeSwappable& b) { - using std::swap; - swap(a.d_, b.d_); - } - - LargeSwappable& operator=(LargeSwappable o) { - using std::swap; - swap(*this, o); - return *this; - } - - std::vector<int> d_; -}; - -} // namespace - -static void BM_LargeSwappableElements(int iters, int len) { - typedef tensorflow::gtl::InlinedVector<LargeSwappable, 32> Vec; - Vec a(len); - Vec b; - while (--iters >= 0) { - using std::swap; - swap(a, b); - } -} -BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024); - -} // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index cb0cb46752..9836f784ab 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -29381,6 +29381,49 @@ op { } } op { + name: "MapDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: 
true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } +} +op { name: "MapDefun" input_arg { name: "arguments" diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index f03639e833..1a5ad8f421 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -198,6 +198,7 @@ REGISTER_OP("MapDataset") .Attr("Targuments: list(type) >= 0") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("use_inter_op_parallelism: bool = true") .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("ParallelMapDataset") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 4419f93d0c..28b25fdeae 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -14542,6 +14542,13 @@ op { has_minimum: true minimum: 1 } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } } op { name: "MapDefun" diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc index a1be4aacce..5e1eabee5b 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.cc +++ b/tensorflow/core/platform/cloud/curl_http_request.cc @@ -394,9 +394,9 @@ size_t CurlHttpRequest::HeaderCallback(const void* ptr, size_t size, .StopCapture() .OneLiteral(": ") .GetResult(&value, &name)) { - string str_value = std::string(value); + string str_value(value); str_util::StripTrailingWhitespace(&str_value); - that->response_headers_[std::string(name)] = str_value; + that->response_headers_[string(name)] = str_value; } return size * nmemb; } diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 
9d33787bd5..8f959c018e 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -179,13 +179,13 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket, return errors::InvalidArgument("GCS path doesn't start with 'gs://': ", fname); } - *bucket = std::string(bucketp); + *bucket = string(bucketp); if (bucket->empty() || *bucket == ".") { return errors::InvalidArgument("GCS path doesn't contain a bucket name: ", fname); } str_util::ConsumePrefix(&objectp, "/"); - *object = std::string(objectp); + *object = string(objectp); if (!empty_object_ok && object->empty()) { return errors::InvalidArgument("GCS path doesn't contain an object name: ", fname); @@ -224,7 +224,7 @@ std::set<string> AddAllSubpaths(const std::vector<string>& paths) { for (const string& path : paths) { StringPiece subpath = io::Dirname(path); while (!subpath.empty()) { - result.emplace(std::string(subpath)); + result.emplace(string(subpath)); subpath = io::Dirname(subpath); } } @@ -723,7 +723,7 @@ GcsFileSystem::GcsFileSystem() { if (!header_name.empty() && !header_value.empty()) { additional_header_.reset(new std::pair<const string, const string>( - std::string(header_name), std::string(header_value))); + string(header_name), string(header_value))); VLOG(1) << "GCS additional header ENABLED. " << "Name: " << additional_header_->first << ", " @@ -1229,7 +1229,7 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern, // Find the fixed prefix by looking for the first wildcard. 
const string& fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\")); - const string& dir = std::string(io::Dirname(fixed_prefix)); + const string dir(io::Dirname(fixed_prefix)); if (dir.empty()) { return errors::InvalidArgument( "A GCS pattern doesn't have a bucket name: ", pattern); @@ -1326,7 +1326,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, " doesn't match the prefix ", object_prefix)); } if (!relative_path.empty() || include_self_directory_marker) { - result->emplace_back(std::string(relative_path)); + result->emplace_back(relative_path); } if (++retrieved_results >= max_results) { return Status::OK(); @@ -1354,7 +1354,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, "Unexpected response: the returned folder name ", prefix_str, " doesn't match the prefix ", object_prefix); } - result->emplace_back(std::string(relative_path)); + result->emplace_back(relative_path); if (++retrieved_results >= max_results) { return Status::OK(); } diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc index ee6ba7b041..9b85cae9b9 100644 --- a/tensorflow/core/platform/cloud/oauth_client.cc +++ b/tensorflow/core/platform/cloud/oauth_client.cc @@ -216,7 +216,7 @@ Status OAuthClient::GetTokenFromServiceAccountJson( // Send the request to the Google OAuth 2.0 server to get the token. 
std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); std::vector<char> response_buffer; - request->SetUri(std::string(oauth_server_uri)); + request->SetUri(string(oauth_server_uri)); request->SetPostFromBuffer(request_body.c_str(), request_body.size()); request->SetResultBuffer(&response_buffer); TF_RETURN_IF_ERROR(request->Send()); @@ -248,7 +248,7 @@ Status OAuthClient::GetTokenFromRefreshTokenJson( std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); std::vector<char> response_buffer; - request->SetUri(std::string(oauth_server_uri)); + request->SetUri(string(oauth_server_uri)); request->SetPostFromBuffer(request_body.c_str(), request_body.size()); request->SetResultBuffer(&response_buffer); TF_RETURN_IF_ERROR(request->Send()); diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc index 4ffa72288b..1cd0641cd3 100644 --- a/tensorflow/core/platform/cloud/oauth_client_test.cc +++ b/tensorflow/core/platform/cloud/oauth_client_test.cc @@ -126,9 +126,9 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) { EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer", grant_type); - int last_dot = std::string(assertion).find_last_of("."); - string header_dot_claim = std::string(assertion.substr(0, last_dot)); - string signature_encoded = std::string(assertion.substr(last_dot + 1)); + int last_dot = assertion.rfind('.'); + string header_dot_claim(assertion.substr(0, last_dot)); + string signature_encoded(assertion.substr(last_dot + 1)); // Check that 'signature' signs 'header_dot_claim'. 
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 07b2e3426b..bb841aeab7 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -625,6 +625,7 @@ def tf_additional_lib_deps(): """Additional dependencies needed to build TF libraries.""" return [ "@com_google_absl//absl/base:base", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:optional", ] + if_static( diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index da3a99565e..625d5649e6 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -390,9 +390,12 @@ message ConfigProto { message Experimental { // Task name for group resolution. string collective_group_leader = 1; - // Whether the client will format templated errors. For example, the string: - // "The node was defined on ^^node:Foo:${file}:${line}^^". - bool client_handles_error_formatting = 2; + + // We removed the flag client_handles_error_formatting. Marking the tag + // number as reserved. + // TODO(shikharagarwal): Should we just remove this tag so that it can be + // used in future for other purpose? + reserved 2; // Which executor to use, the default executor will be used // if it is an empty string or "DEFAULT" diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 1841dd998b..ae0ad27f15 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1132,7 +1132,7 @@ class BaseSession(SessionInterface): for details of the allowable fetch types. feed_list: (Optional.) A list of `feed_dict` keys. See `tf.Session.run` for details of the allowable feed key types. - accept_options: (Optional.) Iff `True`, the returned `Callable` will be + accept_options: (Optional.) 
If `True`, the returned `Callable` will be able to accept `tf.RunOptions` and `tf.RunMetadata` as optional keyword arguments `options` and `run_metadata`, respectively, with the same syntax and semantics as `tf.Session.run`, which is useful @@ -1302,9 +1302,7 @@ class BaseSession(SessionInterface): node_def = op.node_def except KeyError: pass - if (self._config is not None and - self._config.experimental.client_handles_error_formatting): - message = error_interpolation.interpolate(message, self._graph) + message = error_interpolation.interpolate(message, self._graph) raise type(e)(node_def, op, message) def _extend_graph(self): diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 459f494b48..586f4c6936 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -26,7 +26,7 @@ import datetime from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 4) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 5) @tf_export("compat.forward_compatible") diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py index b0414ad655..671e5d4812 100644 --- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py @@ -91,7 +91,7 @@ class IteratorTest(test.TestCase): self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(14): for i in range(7): result = sess.run(get_next) @@ -117,7 +117,7 @@ class IteratorTest(test.TestCase): self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(14): for i in range(7): result = sess.run(get_next) @@ -208,7 +208,7 @@ 
class IteratorTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(next_element) @@ -216,7 +216,7 @@ class IteratorTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(next_element) - with self.test_session() as sess: + with self.cached_session() as sess: def consumer_thread(): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): @@ -287,7 +287,7 @@ class IteratorTest(test.TestCase): .make_initializable_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors.FailedPreconditionError, "iterator has not been initialized"): sess.run(get_next) @@ -308,7 +308,7 @@ class IteratorTest(test.TestCase): self.assertEqual(dataset_4.output_types, iterator.output_types) self.assertEqual([None], iterator.output_shapes.as_list()) - with self.test_session() as sess: + with self.cached_session() as sess: # The iterator is initially uninitialized. 
with self.assertRaises(errors.FailedPreconditionError): sess.run(get_next) @@ -380,7 +380,7 @@ class IteratorTest(test.TestCase): self.assertEqual(dataset_4.output_types, feedable_iterator.output_types) self.assertEqual([], feedable_iterator.output_shapes) - with self.test_session() as sess: + with self.cached_session() as sess: iterator_3_handle = sess.run(iterator_3.string_handle()) iterator_4_handle = sess.run(iterator_4.string_handle()) @@ -436,7 +436,7 @@ class IteratorTest(test.TestCase): self.assertEqual(dataset_4.output_types, feedable_iterator.output_types) self.assertEqual([], feedable_iterator.output_shapes) - with self.test_session() as sess: + with self.cached_session() as sess: iterator_3_handle = sess.run(iterator_3.string_handle()) iterator_4_handle = sess.run(iterator_4.string_handle()) @@ -524,7 +524,7 @@ class IteratorTest(test.TestCase): feedable_int_any = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32) - with self.test_session() as sess: + with self.cached_session() as sess: handle_int_scalar = sess.run( dataset_int_scalar.make_one_shot_iterator().string_handle()) handle_float_vector = sess.run( @@ -687,7 +687,7 @@ class IteratorTest(test.TestCase): f=_remote_fn, target=target_placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: elem = sess.run( remote_op, feed_dict={ @@ -803,16 +803,15 @@ class IteratorCheckpointingTest(test.TestCase): get_next = iterator.get_next if context.executing_eagerly( ) else functools.partial(self.evaluate, iterator.get_next()) checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) - with self.test_session() as sess: - self.assertAllEqual([1, 4], get_next()) - save_path = checkpoint.save(checkpoint_prefix) - self.assertAllEqual([9, 16], get_next()) - self.assertAllEqual([25, 36], get_next()) - checkpoint.restore(save_path).run_restore_ops(sess) - self.assertAllEqual([9, 16], get_next()) - self.assertAllEqual([25, 36], get_next()) - with 
self.assertRaises(errors.OutOfRangeError): - get_next() + self.assertAllEqual([1, 4], get_next()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([9, 16], get_next()) + self.assertAllEqual([25, 36], get_next()) + checkpoint.restore(save_path).run_restore_ops() + self.assertAllEqual([9, 16], get_next()) + self.assertAllEqual([25, 36], get_next()) + with self.assertRaises(errors.OutOfRangeError): + get_next() @test_util.run_in_graph_and_eager_modes def testSaveRestoreMultipleIterator(self): @@ -833,19 +832,18 @@ class IteratorCheckpointingTest(test.TestCase): ) else functools.partial(self.evaluate, iterator_3.get_next()) checkpoint = checkpointable_utils.Checkpoint( iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3) - with self.test_session() as sess: - self.assertAllEqual([1, 4], get_next_1()) - self.assertAllEqual(0, get_next_3()) - self.assertAllEqual(1, get_next_3()) - self.assertAllEqual(2, get_next_3()) - save_path = checkpoint.save(checkpoint_prefix) - self.assertAllEqual([1, 4], get_next_2()) - self.assertAllEqual([9, 16], get_next_2()) - self.assertAllEqual(3, get_next_3()) - checkpoint.restore(save_path).run_restore_ops(sess) - self.assertAllEqual([9, 16], get_next_1()) - self.assertAllEqual([1, 4], get_next_2()) - self.assertAllEqual(3, get_next_3()) + self.assertAllEqual([1, 4], get_next_1()) + self.assertAllEqual(0, get_next_3()) + self.assertAllEqual(1, get_next_3()) + self.assertAllEqual(2, get_next_3()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([1, 4], get_next_2()) + self.assertAllEqual([9, 16], get_next_2()) + self.assertAllEqual(3, get_next_3()) + checkpoint.restore(save_path).run_restore_ops() + self.assertAllEqual([9, 16], get_next_1()) + self.assertAllEqual([1, 4], get_next_2()) + self.assertAllEqual(3, get_next_3()) @test_util.run_in_graph_and_eager_modes def testRestoreExhaustedIterator(self): @@ -856,17 +854,16 @@ class IteratorCheckpointingTest(test.TestCase): get_next = 
iterator.get_next if context.executing_eagerly( ) else functools.partial(self.evaluate, iterator.get_next()) checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) - with self.test_session() as sess: - self.assertAllEqual(0, get_next()) - self.assertAllEqual(1, get_next()) - save_path = checkpoint.save(checkpoint_prefix) - self.assertAllEqual(2, get_next()) - checkpoint.restore(save_path).run_restore_ops(sess) - self.assertAllEqual(2, get_next()) - save_path = checkpoint.save(checkpoint_prefix) - checkpoint.restore(save_path).run_restore_ops(sess) - with self.assertRaises(errors.OutOfRangeError): - get_next() + self.assertAllEqual(0, get_next()) + self.assertAllEqual(1, get_next()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual(2, get_next()) + checkpoint.restore(save_path).run_restore_ops() + self.assertAllEqual(2, get_next()) + save_path = checkpoint.save(checkpoint_prefix) + checkpoint.restore(save_path).run_restore_ops() + with self.assertRaises(errors.OutOfRangeError): + get_next() def testRestoreInReconstructedIteratorInitializable(self): checkpoint_directory = self.get_temp_dir() @@ -876,7 +873,7 @@ class IteratorCheckpointingTest(test.TestCase): get_next = iterator.get_next() checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) for i in range(5): - with self.test_session() as sess: + with self.cached_session() as sess: checkpoint.restore(checkpoint_management.latest_checkpoint( checkpoint_directory)).initialize_or_restore(sess) for j in range(2): diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index 52b4320bf1..df2c9b170a 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -711,57 +711,74 @@ class MapDatasetBenchmark(test.Benchmark): def benchmarkChainOfMaps(self): chain_lengths = [0, 1, 2, 5, 10, 20, 50] for chain_length in chain_lengths: - 
with ops.Graph().as_default(): - dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) - for _ in range(chain_length): - dataset = dataset.map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with session.Session() as sess: - for _ in range(5): - sess.run(next_element.op) - deltas = [] - for _ in range(100): - start = time.time() - for _ in range(100): + for use_inter_op_parallelism in [False, True]: + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) + for _ in range(chain_length): + dataset = dataset_ops.MapDataset( + dataset, + lambda x: x, + use_inter_op_parallelism=use_inter_op_parallelism) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(5): sess.run(next_element.op) - end = time.time() - deltas.append(end - start) - - median_wall_time = np.median(deltas) / 100 - print("Map dataset chain length: %d Median wall time: %f" - % (chain_length, median_wall_time)) - self.report_benchmark( - iters=1000, wall_time=median_wall_time, - name="benchmark_map_dataset_chain_latency_%d" % chain_length) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + print("Map dataset chain length%s: %d Median wall time: %f" % + (" (single threaded mode)" if not use_inter_op_parallelism + else "", chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_dataset_chain_latency_%d%s" % + (chain_length, "_single_threaded" + if not use_inter_op_parallelism else "")) def benchmarkMapFanOut(self): fan_outs = [1, 2, 5, 10, 20, 50, 100] for fan_out in fan_outs: - with ops.Graph().as_default(): - dataset = dataset_ops.Dataset.from_tensors( - tuple(0 for _ in range(fan_out))).repeat(None).map(lambda 
*xs: xs) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with session.Session() as sess: - for _ in range(5): - sess.run(next_element[0].op) - deltas = [] - for _ in range(100): - start = time.time() - for _ in range(100): + for use_inter_op_parallelism in [False, True]: + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors( + tuple(0 for _ in range(fan_out))).repeat(None) + dataset = dataset_ops.MapDataset( + dataset, + lambda *xs: xs, + use_inter_op_parallelism=use_inter_op_parallelism) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(5): sess.run(next_element[0].op) - end = time.time() - deltas.append(end - start) - - median_wall_time = np.median(deltas) / 100 - print("Map dataset fan out: %d Median wall time: %f" - % (fan_out, median_wall_time)) - self.report_benchmark( - iters=1000, wall_time=median_wall_time, - name="benchmark_map_dataset_fan_out_%d" % fan_out) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element[0].op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + print("Map dataset fan out%s: %d Median wall time: %f" % + (" (single threaded mode)" if not use_inter_op_parallelism + else "", fan_out, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_dataset_fan_out_%d%s" % + (fan_out, "_single_threaded" + if not use_inter_op_parallelism else "")) if __name__ == "__main__": diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 8c37b1871b..6205ee392e 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2207,10 +2207,11 @@ def _warn_if_collections(transformation_name): class MapDataset(Dataset): """A `Dataset` that maps a function over elements in 
its input.""" - def __init__(self, input_dataset, map_func): + def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True): """See `Dataset.map()` for details.""" super(MapDataset, self).__init__() self._input_dataset = input_dataset + self._use_inter_op_parallelism = use_inter_op_parallelism wrapped_func = StructuredFunctionWrapper( map_func, "Dataset.map()", input_dataset) @@ -2225,6 +2226,7 @@ class MapDataset(Dataset): input_t, self._map_func.captured_inputs, f=self._map_func, + use_inter_op_parallelism=self._use_inter_op_parallelism, **flat_structure(self)) @property diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index caf36b6a36..6673178ee7 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -64,7 +64,7 @@ class BackpropTest(test.TestCase): grad = backprop.gradients_function(fn, [0])(var)[0] grad = self.evaluate(ops.convert_to_tensor(grad)) - with context.graph_mode(), self.test_session(): + with context.graph_mode(): tf_var = array_ops.constant(var_np, dtypes.float32) tf_ind1 = array_ops.constant([0, 1]) tf_ind2 = array_ops.constant([2, 3]) @@ -79,7 +79,7 @@ class BackpropTest(test.TestCase): tf_dense_grad = math_ops.unsorted_segment_sum( tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0]) - self.assertAllClose(grad, tf_dense_grad.eval()) + self.assertAllClose(grad, self.evaluate(tf_dense_grad)) def testImplicitGradWithResourceVariable(self): x = resource_variable_ops.ResourceVariable( @@ -198,7 +198,7 @@ class BackpropTest(test.TestCase): grad = backprop.implicit_grad(f)()[0][0] opt = training.GradientDescentOptimizer(lrn_rate) - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): tf_x = array_ops.ones((batch_size), dtypes.int64) # TODO(ashankar,apassos): Change to ResourceVariable. 
tf_embedding = variables.Variable( @@ -941,7 +941,7 @@ class BackpropTest(test.TestCase): def testZerosCacheDoesntLeakAcrossGraphs(self): with context.graph_mode(): def get_grad(): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4)) x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4)) with backprop.GradientTape() as tape: diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index c08cf61220..1c0c4581c0 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -142,7 +142,7 @@ def _dnn_model_fn(features, dropout=None, input_layer_partitioner=None, config=None, - tpu_estimator_spec=False, + use_tpu=False, batch_norm=False): """Deep Neural Net model_fn. @@ -164,8 +164,8 @@ def _dnn_model_fn(features, input_layer_partitioner: Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings. - tpu_estimator_spec: Whether to return a `_TPUEstimatorSpec` or - or `model_fn.EstimatorSpec` instance. + use_tpu: Whether to make a DNN model able to run on TPU. Will make function + return a `_TPUEstimatorSpec` instance and disable variable partitioning. batch_norm: Whether to use batch normalization after each hidden layer. 
Returns: @@ -182,13 +182,15 @@ def _dnn_model_fn(features, optimizer, learning_rate=_LEARNING_RATE) num_ps_replicas = config.num_ps_replicas if config else 0 - partitioner = partitioned_variables.min_max_variable_partitioner( - max_partitions=num_ps_replicas) + partitioner = (None if use_tpu else + partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas)) with variable_scope.variable_scope( 'dnn', values=tuple(six.itervalues(features)), partitioner=partitioner): input_layer_partitioner = input_layer_partitioner or ( + None if use_tpu else partitioned_variables.min_max_variable_partitioner( max_partitions=num_ps_replicas, min_slice_size=64 << 20)) @@ -203,7 +205,7 @@ def _dnn_model_fn(features, batch_norm=batch_norm) logits = logit_fn(features=features, mode=mode) - if tpu_estimator_spec: + if use_tpu: return head._create_tpu_estimator_spec( # pylint: disable=protected-access features=features, mode=mode, diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py index a69018d00d..46bda2e621 100644 --- a/tensorflow/python/framework/error_interpolation.py +++ b/tensorflow/python/framework/error_interpolation.py @@ -15,7 +15,7 @@ """Function for interpolating formatted errors from the TensorFlow runtime. Exposes the function `interpolate` to interpolate messages with tags of the form -^^type:name:format^^. +{{type name}}. """ from __future__ import absolute_import @@ -32,7 +32,7 @@ import six from tensorflow.python.util import tf_stack _NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?" -_TAG_REGEX = r"\^\^({name}):({name})\^\^".format(name=_NAME_REGEX) +_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX) _INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX) _INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX) @@ -48,8 +48,8 @@ def _parse_message(message): """Parses the message. Splits the message into separators and tags. 
Tags are named tuples - representing the string ^^type:name^^ and they are separated by - separators. For example, in "123^^node:Foo^^456^^node:Bar^^789", there are + representing the string {{type name}} and they are separated by + separators. For example, in "123{{node Foo}}456{{node Bar}}789", there are two tags and three separators. The separators are the numeric characters. Args: @@ -58,7 +58,7 @@ def _parse_message(message): Returns: (list of separator strings, list of _ParseTags). - For example, if message is "123^^node:Foo^^456" then this function + For example, if message is "123{{node Foo}}456" then this function returns (["123", "456"], [_ParseTag("node", "Foo")]) """ seps = [] @@ -276,7 +276,7 @@ def interpolate(error_message, graph): message. Returns: - The string with tags of the form ^^type:name^^ interpolated. + The string with tags of the form {{type name}} interpolated. """ seps, tags = _parse_message(error_message) subs = [] @@ -288,7 +288,7 @@ def interpolate(error_message, graph): except KeyError: op = None - msg = "^^%s:%s^^" % (t.type, t.name) + msg = "{{%s %s}}" % (t.type, t.name) if op is not None: field_dict = compute_field_dict(op) if t.type == "node": diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py index a7c7bbf28b..d312b825d2 100644 --- a/tensorflow/python/framework/error_interpolation_test.py +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -167,20 +167,20 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase): self.assertEqual(interpolated_string, normal_string) def testOneTagWithAFakeNameResultsInPlaceholders(self): - one_tag_string = "^^node:MinusOne^^" + one_tag_string = "{{node MinusOne}}" interpolated_string = error_interpolation.interpolate( one_tag_string, self.graph) self.assertEqual(one_tag_string, interpolated_string) def testTwoTagsNoSeps(self): - two_tags_no_seps = "^^node:One^^^^node:Three^^" + two_tags_no_seps = "{{node 
One}}{{node Three}}" interpolated_string = error_interpolation.interpolate( two_tags_no_seps, self.graph) self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*constant_op.py:[0-9]+") def testTwoTagsWithSeps(self): - two_tags_with_seps = ";;;^^node:Two^^,,,^^node:Three^^;;;" + two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;" interpolated_string = error_interpolation.interpolate( two_tags_with_seps, self.graph) expected_regex = ( @@ -206,23 +206,23 @@ class InterpolateDeviceSummaryTest(test.TestCase): self.graph = self.three.graph def testNodeZeroHasNoDeviceSummaryInfo(self): - message = "^^colocation_node:zero^^" + message = "{{colocation_node zero}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("No device assignments were active", result) def testNodeOneHasExactlyOneInterpolatedDevice(self): - message = "^^colocation_node:one^^" + message = "{{colocation_node one}}" result = error_interpolation.interpolate(message, self.graph) self.assertEqual(2, result.count("tf.device(/cpu)")) def testNodeTwoHasTwoInterpolatedDevice(self): - message = "^^colocation_node:two^^" + message = "{{colocation_node two}}" result = error_interpolation.interpolate(message, self.graph) self.assertEqual(2, result.count("tf.device(/cpu)")) self.assertEqual(2, result.count("tf.device(/cpu:0)")) def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self): - message = "^^colocation_node:three^^" + message = "{{colocation_node three}}" result = error_interpolation.interpolate(message, self.graph) num_devices = result.count("tf.device") self.assertEqual(2, num_devices) @@ -256,12 +256,12 @@ class InterpolateColocationSummaryTest(test.TestCase): self.graph = node_three.graph def testNodeThreeHasColocationInterpolation(self): - message = "^^colocation_node:Three_with_one^^" + message = "{{colocation_node Three_with_one}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("colocate_with(One)", result) def 
testNodeFourHasColocationInterpolationForNodeThreeOnly(self): - message = "^^colocation_node:Four_with_three^^" + message = "{{colocation_node Four_with_three}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("colocate_with(Three_with_one)", result) self.assertNotIn( @@ -269,13 +269,13 @@ class InterpolateColocationSummaryTest(test.TestCase): "Node One should not appear in Four_with_three's summary:\n%s" % result) def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self): - message = "^^colocation_node:Five_with_one_with_two^^" + message = "{{colocation_node Five_with_one_with_two}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("colocate_with(One)", result) self.assertIn("colocate_with(Two)", result) def testColocationInterpolationForNodeLackingColocation(self): - message = "^^colocation_node:One^^" + message = "{{colocation_node One}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("No node-device colocations", result) self.assertNotIn("Two", result) diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index b14290c203..26170b000d 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -367,7 +367,7 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): A `TensorProto`. Depending on the type, it may contain data in the "tensor_content" attribute, which is not directly useful to Python programs. To access the values you should convert the proto back to a numpy ndarray - with `tensor_util.MakeNdarray(proto)`. + with `tf.make_ndarray(proto)`. If `values` is a `TensorProto`, it is immediately returned; `dtype` and `shape` are ignored. 
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index b5388ad0b2..3b63e49a84 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -535,15 +535,16 @@ def assert_no_new_tensors(f): tensors_before = set( id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)) - if context.executing_eagerly(): - f(self, **kwargs) - ops.reset_default_graph() - else: - # Run the test in a new graph so that collections get cleared when it's - # done, but inherit the graph key so optimizers behave. - outside_graph_key = ops.get_default_graph()._graph_key - with ops.Graph().as_default(): - ops.get_default_graph()._graph_key = outside_graph_key + outside_executed_eagerly = context.executing_eagerly() + # Run the test in a new graph so that collections get cleared when it's + # done, but inherit the graph key so optimizers behave. + outside_graph_key = ops.get_default_graph()._graph_key + with ops.Graph().as_default(): + ops.get_default_graph()._graph_key = outside_graph_key + if outside_executed_eagerly: + with context.eager_mode(): + f(self, **kwargs) + else: f(self, **kwargs) # Make an effort to clear caches, which would otherwise look like leaked # Tensors. 
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index b52ab7f05c..7768caeaf0 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -443,13 +443,7 @@ def get_session(): session = default_session else: if _SESSION is None: - if not os.environ.get('OMP_NUM_THREADS'): - config = config_pb2.ConfigProto(allow_soft_placement=True) - else: - num_thread = int(os.environ.get('OMP_NUM_THREADS')) - config = config_pb2.ConfigProto( - intra_op_parallelism_threads=num_thread, allow_soft_placement=True) - _SESSION = session_module.Session(config=config) + _SESSION = session_module.Session(config=get_default_session_config()) session = _SESSION if not _MANUAL_VAR_INIT: with session.graph.as_default(): @@ -468,6 +462,16 @@ def set_session(session): _SESSION = session +def get_default_session_config(): + if not os.environ.get('OMP_NUM_THREADS'): + config = config_pb2.ConfigProto(allow_soft_placement=True) + else: + num_thread = int(os.environ.get('OMP_NUM_THREADS')) + config = config_pb2.ConfigProto( + intra_op_parallelism_threads=num_thread, allow_soft_placement=True) + return config + + # DEVICE MANIPULATION diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py index fcb073322c..c1c4970025 100644 --- a/tensorflow/python/keras/engine/distributed_training_utils.py +++ b/tensorflow/python/keras/engine/distributed_training_utils.py @@ -17,8 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.client import session as session_module from tensorflow.python.framework import tensor_util -from tensorflow.python.keras import backend +from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import distribute as 
distribute_lib @@ -46,7 +47,7 @@ def set_weights(distribution_strategy, dist_model, weights): assign_ops.append(distribution_strategy.unwrap(sw.assign(w))) weights = weights[num_param:] - backend.get_session().run(assign_ops) + K.get_session().run(assign_ops) def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs, @@ -269,3 +270,20 @@ def validate_all_tensor_shapes(x, x_values): if x_shape != x_values[i].get_shape().as_list(): raise ValueError('Input tensor shapes do not match for distributed tensor' ' inputs {}'.format(x)) + + +def configure_and_create_session(distribution_strategy): + """Configure session config and create a session with it.""" + # TODO(priyag): Throw error if a session already exists. + session_config = K.get_default_session_config() + distribution_strategy.configure(session_config) + + if distribution_strategy.__class__.__name__ == 'TPUStrategy': + # TODO(priyag): Remove this workaround when Distributed Coordinator is + # integrated with keras and we can create a session from there. + master = distribution_strategy._tpu_cluster_resolver.master() # pylint: disable=protected-access + session = session_module.Session(config=session_config, target=master) + else: + session = session_module.Session(config=session_config) + + K.set_session(session) diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index cd74e36e68..f8c23ed124 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -1355,7 +1355,9 @@ class Network(base_layer.Layer): ``` """ if not self._is_graph_network: - raise NotImplementedError + raise NotImplementedError( + 'Currently `save` requires model to be a graph network. 
Consider ' + 'using `save_weights`, in order to save the weights of the model.') from tensorflow.python.keras.models import save_model # pylint: disable=g-import-not-at-top save_model(self, filepath, overwrite, include_optimizer) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 85d25411b4..966b446f22 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -405,20 +405,9 @@ class Model(Network): # Set DistributionStrategy specific parameters. self._distribution_strategy = distribute if self._distribution_strategy is not None: - self._grouped_model = self._compile_distributed_model( + self._grouped_model = None + distributed_training_utils.configure_and_create_session( self._distribution_strategy) - with self._distribution_strategy.scope(): - first_replicated_model = self._distribution_strategy.unwrap( - self._grouped_model)[0] - # If the specified metrics in `compile` are stateful, raise an error - # since we currently don't support stateful metrics. - if first_replicated_model.stateful_metric_names: - raise NotImplementedError('Stateful metrics are not supported with ' - 'DistributionStrategy.') - - # We initialize the callback model with the first replicated model. - self._replicated_model = DistributedCallbackModel(first_replicated_model) - self._replicated_model.set_original_model(self) if not self.built: # Model is not compilable because it does not know its number of inputs # and outputs, nor their shapes and names. We will compile after the first @@ -636,6 +625,12 @@ class Model(Network): skip_target_indices=skip_target_indices, sample_weights=self.sample_weights) + # If using distribution strategy and stateful_metrics, raise an error + # since we currently don't support stateful metrics. 
+ if self._distribution_strategy is not None and self.stateful_metric_names: + raise NotImplementedError('Stateful metrics are not supported with ' + 'DistributionStrategy.') + # Prepare gradient updates and state updates. self.total_loss = total_loss @@ -652,19 +647,6 @@ class Model(Network): trainable_weights = self.trainable_weights self._collected_trainable_weights = trainable_weights - def _compile_distributed_model(self, distribution_strategy): - # TODO(anjalisridhar): Can we move the clone_and_build_model to outside the - # model? - def _clone_model_per_tower(model): - new_model = training_distributed.clone_and_build_model(model) - return new_model - - with distribution_strategy.scope(): - # Create a copy of this model on each of the devices. - grouped_models = distribution_strategy.call_for_each_tower( - _clone_model_per_tower, self) - return grouped_models - def _check_trainable_weights_consistency(self): """Check trainable weights count consistency. @@ -790,10 +772,7 @@ class Model(Network): Fraction of the training data to be used as validation data. Returns: - A tuple of 3 lists: input arrays, target arrays, sample-weight arrays. - If the model's input and targets are symbolic, these lists are empty - (since the model takes no user-provided data, instead the data comes - from the symbolic inputs/targets). + Iterator for reading the dataset `x`. Raises: ValueError: In case of invalid user-provided data. @@ -828,30 +807,7 @@ class Model(Network): training_utils.validate_iterator_input(x, y, sample_weight, validation_split) - # x an y may be PerDevice objects with an input and output tensor - # corresponding to each device. For example, x could be - # PerDevice:{device: get_next tensor,...}. - next_element = iterator.get_next() - - if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: - raise ValueError('Please provide model inputs as a list or tuple of 2 ' - 'elements: input and target pair. 
' - 'Received %s' % next_element) - x, y = next_element - # Validate that all the elements in x and y are of the same type and shape. - # We can then pass the first element of x and y to `_standardize_weights` - # below and be confident of the output. We need to reopen the scope since - # we unwrap values when we validate x and y. - with self._distribution_strategy.scope(): - x_values, y_values = distributed_training_utils.\ - validate_distributed_dataset_inputs(self._distribution_strategy, x, y) - - _, _, sample_weights = self._standardize_weights(x_values, - y_values, - sample_weight, - class_weight, - batch_size) - return x, y, sample_weights + return iterator def _standardize_user_data(self, x, @@ -916,7 +872,7 @@ class Model(Network): RuntimeError: If the model was never compiled. """ if self._distribution_strategy: - return self._distribution_standardize_user_data( + iterator = self._distribution_standardize_user_data( x, y, sample_weight=sample_weight, @@ -926,6 +882,7 @@ class Model(Network): steps_name=steps_name, steps=steps, validation_split=validation_split) + return iterator, None, None if isinstance(x, dataset_ops.Dataset): if context.executing_eagerly(): @@ -982,6 +939,7 @@ class Model(Network): def _standardize_weights(self, x, y, sample_weight=None, class_weight=None, batch_size=None,): + # TODO(sourabhbajaj): Split input validation from weight standardization. if sample_weight is not None and class_weight is not None: logging.warning( 'Received both a `sample_weight` and `class_weight` argument. 
' @@ -1566,12 +1524,11 @@ class Model(Network): validation_steps=validation_steps) elif self._distribution_strategy: return training_distributed.fit_loop( - self, x, y, + self, x, epochs=epochs, verbose=verbose, callbacks=callbacks, - val_inputs=val_x, - val_targets=val_y, + val_iterator=val_x, initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch, validation_steps=validation_steps) @@ -1677,8 +1634,7 @@ class Model(Network): elif self._distribution_strategy: return training_distributed.test_loop( self, - inputs=x, - targets=y, + iterator=x, verbose=verbose, steps=steps) else: @@ -2188,6 +2144,13 @@ class Model(Network): return self.callback_model return self + def _make_callback_model(self): + first_replicated_model = self._distribution_strategy.unwrap( + self._grouped_model)[0] + # We initialize the callback model with the first replicated model. + self._replicated_model = DistributedCallbackModel(first_replicated_model) + self._replicated_model.set_original_model(self) + class DistributedCallbackModel(Model): """Model that is used for callbacks with DistributionStrategy.""" @@ -2225,6 +2188,6 @@ class DistributedCallbackModel(Model): # Whitelisted atttributes of the model that can be accessed by the user # during a callback. 
if item not in ['_setattr_tracking']: - logging.warning('You are accessing attribute ' + item + 'of the' - 'DistributedCallbackModel that may not have been set' + logging.warning('You are accessing attribute ' + item + 'of the ' + 'DistributedCallbackModel that may not have been set ' 'correctly.') diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index 85f1d6299f..a7bb1f8177 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -30,13 +30,11 @@ from tensorflow.python.platform import tf_logging as logging def fit_loop( model, - inputs, - targets, + iterator, epochs=100, verbose=1, callbacks=None, - val_inputs=None, - val_targets=None, + val_iterator=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None): @@ -44,13 +42,11 @@ def fit_loop( Arguments: model: Keras Model instance. - inputs: List of input arrays. - targets: List of target arrays. + iterator: Iterator for input data. epochs: Number of times to iterate over the data verbose: Verbosity mode, 0, 1 or 2 callbacks: List of callbacks to be called during training - val_inputs: List of input arrays. - val_targets: List of target arrays. + val_iterator: Iterator for validation data. initial_epoch: Epoch at which to start training (useful for resuming a previous training run) steps_per_epoch: Total number of steps (batches of samples) @@ -67,6 +63,10 @@ def fit_loop( ValueError: in case of invalid arguments. 
""" current_strategy = model._distribution_strategy + + clone_model_on_towers( + model, current_strategy, make_callback_model=True) + def _per_device_train_function(model): model._make_train_function() return (model.train_function.inputs, @@ -74,6 +74,7 @@ def fit_loop( model.train_function.updates_op, model.train_function.session_kwargs) + inputs, targets = _get_input_from_iterator(iterator, model) with current_strategy.scope(): # Create train ops on each of the devices when we call # `_per_device_train_function`. @@ -169,8 +170,7 @@ def fit_loop( if do_validation: val_outs = test_loop( model, - val_inputs, - val_targets, + val_iterator, steps=validation_steps, verbose=0) if not isinstance(val_outs, list): @@ -192,13 +192,12 @@ def fit_loop( return model.history -def test_loop(model, inputs, targets, verbose=0, steps=None): +def test_loop(model, iterator, verbose=0, steps=None): """evaluate method to validate a model that uses DistributionStrategy. Arguments: model: Keras Model instance. - inputs: List of input arrays. - targets: List of target arrays. + iterator: Iterator for input data. verbose: verbosity mode. steps: Total number of steps (batches of samples) before declaring predictions finished. @@ -211,6 +210,9 @@ def test_loop(model, inputs, targets, verbose=0, steps=None): the display labels for the scalar outputs. 
""" current_strategy = model._distribution_strategy + + clone_model_on_towers(model, current_strategy) + def _per_device_test_function(model): model._make_test_function() return (model.test_function.inputs, @@ -218,6 +220,7 @@ def test_loop(model, inputs, targets, verbose=0, steps=None): model.test_function.updates_op, model.test_function.session_kwargs) + inputs, targets = _get_input_from_iterator(iterator, model) with current_strategy.scope(): (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_tower( @@ -284,12 +287,12 @@ def test_loop(model, inputs, targets, verbose=0, steps=None): return outs -def predict_loop(model, inputs, verbose=0, steps=None): +def predict_loop(model, iterator, verbose=0, steps=None): """Abstract method to loop over some data in batches. Arguments: model: Keras Model instance. - inputs: list of tensors to be fed to `f`. + iterator: Iterator for input data. verbose: verbosity mode. steps: Total number of steps (batches of samples) before declaring `_predict_loop` finished. @@ -301,6 +304,9 @@ def predict_loop(model, inputs, verbose=0, steps=None): (if the model has multiple outputs). 
""" current_strategy = model._distribution_strategy + + clone_model_on_towers(model, current_strategy) + def _per_device_predict_function(model): model._make_predict_function() return (model.predict_function.inputs, @@ -308,6 +314,7 @@ def predict_loop(model, inputs, verbose=0, steps=None): model.predict_function.updates_op, model.predict_function.session_kwargs) + inputs, _ = _get_input_from_iterator(iterator, model) with current_strategy.scope(): (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_tower( @@ -366,7 +373,7 @@ def predict_loop(model, inputs, verbose=0, steps=None): ] -def clone_and_build_model(model): +def _clone_and_build_model(model): """Clone and build the given keras_model.""" # We need to set the import here since we run into a circular dependency # error. @@ -390,6 +397,16 @@ def clone_and_build_model(model): return cloned_model +def clone_model_on_towers(model, strategy, make_callback_model=False): + """Create a cloned model on each tower, unless already created.""" + if not model._grouped_model: + with strategy.scope(): + model._grouped_model = strategy.call_for_each_tower( + _clone_and_build_model, model) + if make_callback_model: + model._make_callback_model() + + def _aggregate_metrics_across_towers(num_devices, out_labels, outs): """Aggregate metrics values across all towers. @@ -419,3 +436,25 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs): merged_output.append(m) current_index += num_devices return merged_output + + +def _get_input_from_iterator(iterator, model): + """Get elements from the iterator and verify the input shape and type.""" + next_element = iterator.get_next() + # TODO(anjalisridhar): Support predict input correctly as it will not contain + # targets, only inputs. + if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: + raise ValueError('Please provide model inputs as a list or tuple of 2 ' + 'elements: input and target pair. 
' + 'Received %s' % next_element) + + x, y = next_element + # Validate that all the elements in x and y are of the same type and shape. + # We can then pass the first element of x and y to `_standardize_weights` + # below and be confident of the output. + x_values, y_values = distributed_training_utils.\ + validate_distributed_dataset_inputs(model._distribution_strategy, x, y) + # TODO(sourabhbajaj): Add support for sample weights in distribution + # strategy. + model._standardize_weights(x_values, y_values) + return x, y diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 05f998d0d2..680d0c97cc 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -116,7 +116,7 @@ class AssertEqualTest(test.TestCase): check_ops.assert_equal(static_big, static_small, message="fail") def test_raises_when_greater_dynamic(self): - with self.test_session(): + with self.cached_session(): small = array_ops.placeholder(dtypes.int32, name="small") big = array_ops.placeholder(dtypes.int32, name="big") with ops.control_dependencies( @@ -194,7 +194,7 @@ First 2 elements of y: check_ops.assert_equal(static_big, static_small, message="fail") def test_raises_when_less_dynamic(self): - with self.test_session(): + with self.cached_session(): small = array_ops.placeholder(dtypes.int32, name="small") big = array_ops.placeholder(dtypes.int32, name="big") with ops.control_dependencies([check_ops.assert_equal(small, big)]): @@ -271,30 +271,28 @@ class AssertNoneEqualTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def test_raises_when_not_equal_but_non_broadcastable_shapes(self): - with self.test_session(): - small = constant_op.constant([1, 1, 1], name="small") - big = constant_op.constant([10, 10], name="big") - # The exception in eager and non-eager mode is different because - # eager mode relies on shape check done as part of the C++ op, while - # graph 
mode does shape checks when creating the `Operation` instance. - with self.assertRaisesRegexp( - (ValueError, errors.InvalidArgumentError), - (r"Incompatible shapes: \[3\] vs. \[2\]|" - r"Dimensions must be equal, but are 3 and 2")): - with ops.control_dependencies( - [check_ops.assert_none_equal(small, big)]): - out = array_ops.identity(small) - self.evaluate(out) + small = constant_op.constant([1, 1, 1], name="small") + big = constant_op.constant([10, 10], name="big") + # The exception in eager and non-eager mode is different because + # eager mode relies on shape check done as part of the C++ op, while + # graph mode does shape checks when creating the `Operation` instance. + with self.assertRaisesRegexp( + (ValueError, errors.InvalidArgumentError), + (r"Incompatible shapes: \[3\] vs. \[2\]|" + r"Dimensions must be equal, but are 3 and 2")): + with ops.control_dependencies( + [check_ops.assert_none_equal(small, big)]): + out = array_ops.identity(small) + self.evaluate(out) @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): - with self.test_session(): - larry = constant_op.constant([]) - curly = constant_op.constant([]) - with ops.control_dependencies( - [check_ops.assert_none_equal(larry, curly)]): - out = array_ops.identity(larry) - self.evaluate(out) + larry = constant_op.constant([]) + curly = constant_op.constant([]) + with ops.control_dependencies( + [check_ops.assert_none_equal(larry, curly)]): + out = array_ops.identity(larry) + self.evaluate(out) def test_returns_none_with_eager(self): with context.eager_mode(): @@ -905,7 +903,7 @@ class AssertRankTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 1 with ops.control_dependencies( @@ -923,7 +921,7 @@ class AssertRankTest(test.TestCase): 
self.evaluate(array_ops.identity(tensor)) def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 0 with ops.control_dependencies( @@ -940,7 +938,7 @@ class AssertRankTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_one_tensor_raises_if_rank_too_large_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 0 with ops.control_dependencies( @@ -957,7 +955,7 @@ class AssertRankTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 1 with ops.control_dependencies( @@ -974,7 +972,7 @@ class AssertRankTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 2 with ops.control_dependencies( @@ -989,7 +987,7 @@ class AssertRankTest(test.TestCase): check_ops.assert_rank(tensor, np.array([], dtype=np.int32)) def test_raises_if_rank_is_not_scalar_dynamic(self): - with self.test_session(): + with self.cached_session(): tensor = constant_op.constant( [1, 2], dtype=dtypes.float32, name="my_tensor") rank_tensor = array_ops.placeholder(dtypes.int32, name="rank_tensor") @@ -1006,7 +1004,7 @@ class AssertRankTest(test.TestCase): check_ops.assert_rank(tensor, .5) def test_raises_if_rank_is_not_integer_dynamic(self): - with self.test_session(): + with self.cached_session(): tensor = constant_op.constant( [1, 2], dtype=dtypes.float32, name="my_tensor") rank_tensor = 
array_ops.placeholder(dtypes.float32, name="rank_tensor") @@ -1029,7 +1027,7 @@ class AssertRankInTest(test.TestCase): self.evaluate(array_ops.identity(tensor_rank0)) def test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor") with ops.control_dependencies([ check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]): @@ -1045,7 +1043,7 @@ class AssertRankInTest(test.TestCase): self.evaluate(array_ops.identity(tensor_rank0)) def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor") for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): with ops.control_dependencies([ @@ -1061,7 +1059,7 @@ class AssertRankInTest(test.TestCase): self.evaluate(array_ops.identity(tensor_rank1)) def test_rank_one_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor") for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): with ops.control_dependencies([ @@ -1079,7 +1077,7 @@ class AssertRankInTest(test.TestCase): self.evaluate(array_ops.identity(tensor_rank1)) def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor") with ops.control_dependencies([ check_ops.assert_rank_in(tensor_rank1, (0, 2))]): @@ -1098,7 +1096,7 @@ class AssertRankInTest(test.TestCase): check_ops.assert_rank_in(tensor, desired_ranks) def test_raises_if_rank_is_not_scalar_dynamic(self): - with self.test_session(): + with self.cached_session(): tensor = constant_op.constant( (42, 43), dtype=dtypes.float32, name="my_tensor") desired_ranks = ( @@ -1120,7 +1118,7 @@ 
class AssertRankInTest(test.TestCase): check_ops.assert_rank_in(tensor, (1, .5,)) def test_raises_if_rank_is_not_integer_dynamic(self): - with self.test_session(): + with self.cached_session(): tensor = constant_op.constant( (42, 43), dtype=dtypes.float32, name="my_tensor") rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor") @@ -1143,7 +1141,7 @@ class AssertRankAtLeastTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 1 with ops.control_dependencies( @@ -1160,7 +1158,7 @@ class AssertRankAtLeastTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 0 with ops.control_dependencies( @@ -1176,7 +1174,7 @@ class AssertRankAtLeastTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_one_ten_doesnt_raise_if_rank_too_large_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 0 with ops.control_dependencies( @@ -1192,7 +1190,7 @@ class AssertRankAtLeastTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 1 with ops.control_dependencies( @@ -1209,7 +1207,7 @@ class AssertRankAtLeastTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): 
tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 2 with ops.control_dependencies( diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py index 9ad77a54cb..26d013bccb 100644 --- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py +++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py @@ -62,59 +62,50 @@ class BernoulliTest(test.TestCase): def testP(self): p = [0.2, 0.4] dist = bernoulli.Bernoulli(probs=p) - with self.test_session(): - self.assertAllClose(p, self.evaluate(dist.probs)) + self.assertAllClose(p, self.evaluate(dist.probs)) @test_util.run_in_graph_and_eager_modes def testLogits(self): logits = [-42., 42.] dist = bernoulli.Bernoulli(logits=logits) - with self.test_session(): - self.assertAllClose(logits, self.evaluate(dist.logits)) + self.assertAllClose(logits, self.evaluate(dist.logits)) if not special: return - with self.test_session(): - self.assertAllClose(special.expit(logits), self.evaluate(dist.probs)) + self.assertAllClose(special.expit(logits), self.evaluate(dist.probs)) p = [0.01, 0.99, 0.42] dist = bernoulli.Bernoulli(probs=p) - with self.test_session(): - self.assertAllClose(special.logit(p), self.evaluate(dist.logits)) + self.assertAllClose(special.logit(p), self.evaluate(dist.logits)) @test_util.run_in_graph_and_eager_modes def testInvalidP(self): invalid_ps = [1.01, 2.] for p in invalid_ps: - with self.test_session(): - with self.assertRaisesOpError("probs has components greater than 1"): - dist = bernoulli.Bernoulli(probs=p, validate_args=True) - self.evaluate(dist.probs) + with self.assertRaisesOpError("probs has components greater than 1"): + dist = bernoulli.Bernoulli(probs=p, validate_args=True) + self.evaluate(dist.probs) invalid_ps = [-0.01, -3.] 
for p in invalid_ps: - with self.test_session(): - with self.assertRaisesOpError("Condition x >= 0"): - dist = bernoulli.Bernoulli(probs=p, validate_args=True) - self.evaluate(dist.probs) + with self.assertRaisesOpError("Condition x >= 0"): + dist = bernoulli.Bernoulli(probs=p, validate_args=True) + self.evaluate(dist.probs) valid_ps = [0.0, 0.5, 1.0] for p in valid_ps: - with self.test_session(): - dist = bernoulli.Bernoulli(probs=p) - self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail + dist = bernoulli.Bernoulli(probs=p) + self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail @test_util.run_in_graph_and_eager_modes def testShapes(self): - with self.test_session(): - for batch_shape in ([], [1], [2, 3, 4]): - dist = make_bernoulli(batch_shape) - self.assertAllEqual(batch_shape, dist.batch_shape.as_list()) - self.assertAllEqual(batch_shape, - self.evaluate(dist.batch_shape_tensor())) - self.assertAllEqual([], dist.event_shape.as_list()) - self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) + for batch_shape in ([], [1], [2, 3, 4]): + dist = make_bernoulli(batch_shape) + self.assertAllEqual(batch_shape, dist.batch_shape.as_list()) + self.assertAllEqual(batch_shape, self.evaluate(dist.batch_shape_tensor())) + self.assertAllEqual([], dist.event_shape.as_list()) + self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) @test_util.run_in_graph_and_eager_modes def testDtype(self): @@ -137,31 +128,29 @@ class BernoulliTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def _testPmf(self, **kwargs): dist = bernoulli.Bernoulli(**kwargs) - with self.test_session(): - # pylint: disable=bad-continuation - xs = [ - 0, - [1], - [1, 0], - [[1, 0]], - [[1, 0], [1, 1]], - ] - expected_pmfs = [ - [[0.8, 0.6], [0.7, 0.4]], - [[0.2, 0.4], [0.3, 0.6]], - [[0.2, 0.6], [0.3, 0.4]], - [[0.2, 0.6], [0.3, 0.4]], - [[0.2, 0.6], [0.3, 0.6]], - ] - # pylint: enable=bad-continuation - - for x, expected_pmf in zip(xs, expected_pmfs): - 
self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf) - self.assertAllClose( - self.evaluate(dist.log_prob(x)), np.log(expected_pmf)) + # pylint: disable=bad-continuation + xs = [ + 0, + [1], + [1, 0], + [[1, 0]], + [[1, 0], [1, 1]], + ] + expected_pmfs = [ + [[0.8, 0.6], [0.7, 0.4]], + [[0.2, 0.4], [0.3, 0.6]], + [[0.2, 0.6], [0.3, 0.4]], + [[0.2, 0.6], [0.3, 0.4]], + [[0.2, 0.6], [0.3, 0.6]], + ] + # pylint: enable=bad-continuation + + for x, expected_pmf in zip(xs, expected_pmfs): + self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf) + self.assertAllClose(self.evaluate(dist.log_prob(x)), np.log(expected_pmf)) def testPmfCorrectBroadcastDynamicShape(self): - with self.test_session(): + with self.cached_session(): p = array_ops.placeholder(dtype=dtypes.float32) dist = bernoulli.Bernoulli(probs=p) event1 = [1, 0, 1] @@ -178,12 +167,11 @@ class BernoulliTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testPmfInvalid(self): p = [0.1, 0.2, 0.7] - with self.test_session(): - dist = bernoulli.Bernoulli(probs=p, validate_args=True) - with self.assertRaisesOpError("must be non-negative."): - self.evaluate(dist.prob([1, 1, -1])) - with self.assertRaisesOpError("Elements cannot exceed 1."): - self.evaluate(dist.prob([2, 0, 1])) + dist = bernoulli.Bernoulli(probs=p, validate_args=True) + with self.assertRaisesOpError("must be non-negative."): + self.evaluate(dist.prob([1, 1, -1])) + with self.assertRaisesOpError("Elements cannot exceed 1."): + self.evaluate(dist.prob([2, 0, 1])) @test_util.run_in_graph_and_eager_modes def testPmfWithP(self): @@ -194,7 +182,7 @@ class BernoulliTest(test.TestCase): self._testPmf(logits=special.logit(p)) def testBroadcasting(self): - with self.test_session(): + with self.cached_session(): p = array_ops.placeholder(dtypes.float32) dist = bernoulli.Bernoulli(probs=p) self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5})) @@ -208,70 +196,63 @@ class BernoulliTest(test.TestCase): })) def 
testPmfShapes(self): - with self.test_session(): + with self.cached_session(): p = array_ops.placeholder(dtypes.float32, shape=[None, 1]) dist = bernoulli.Bernoulli(probs=p) self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape)) - with self.test_session(): dist = bernoulli.Bernoulli(probs=0.5) self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape)) - with self.test_session(): dist = bernoulli.Bernoulli(probs=0.5) self.assertEqual((), dist.log_prob(1).get_shape()) self.assertEqual((1), dist.log_prob([1]).get_shape()) self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape()) - with self.test_session(): dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]]) self.assertEqual((2, 1), dist.log_prob(1).get_shape()) @test_util.run_in_graph_and_eager_modes def testBoundaryConditions(self): - with self.test_session(): - dist = bernoulli.Bernoulli(probs=1.0) - self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0))) - self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))]) + dist = bernoulli.Bernoulli(probs=1.0) + self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0))) + self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))]) @test_util.run_in_graph_and_eager_modes def testEntropyNoBatch(self): p = 0.2 dist = bernoulli.Bernoulli(probs=p) - with self.test_session(): - self.assertAllClose(self.evaluate(dist.entropy()), entropy(p)) + self.assertAllClose(self.evaluate(dist.entropy()), entropy(p)) @test_util.run_in_graph_and_eager_modes def testEntropyWithBatch(self): p = [[0.1, 0.7], [0.2, 0.6]] dist = bernoulli.Bernoulli(probs=p, validate_args=False) - with self.test_session(): - self.assertAllClose( - self.evaluate(dist.entropy()), - [[entropy(0.1), entropy(0.7)], [entropy(0.2), - entropy(0.6)]]) + self.assertAllClose( + self.evaluate(dist.entropy()), + [[entropy(0.1), entropy(0.7)], [entropy(0.2), + entropy(0.6)]]) @test_util.run_in_graph_and_eager_modes def testSampleN(self): - with self.test_session(): - p = 
[0.2, 0.6] - dist = bernoulli.Bernoulli(probs=p) - n = 100000 - samples = dist.sample(n) - samples.set_shape([n, 2]) - self.assertEqual(samples.dtype, dtypes.int32) - sample_values = self.evaluate(samples) - self.assertTrue(np.all(sample_values >= 0)) - self.assertTrue(np.all(sample_values <= 1)) - # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) / - # n). This means that the tolerance is very sensitive to the value of p - # as well as n. - self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2) - self.assertEqual(set([0, 1]), set(sample_values.flatten())) - # In this test we're just interested in verifying there isn't a crash - # owing to mismatched types. b/30940152 - dist = bernoulli.Bernoulli(np.log([.2, .4])) - self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list()) + p = [0.2, 0.6] + dist = bernoulli.Bernoulli(probs=p) + n = 100000 + samples = dist.sample(n) + samples.set_shape([n, 2]) + self.assertEqual(samples.dtype, dtypes.int32) + sample_values = self.evaluate(samples) + self.assertTrue(np.all(sample_values >= 0)) + self.assertTrue(np.all(sample_values <= 1)) + # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) / + # n). This means that the tolerance is very sensitive to the value of p + # as well as n. + self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2) + self.assertEqual(set([0, 1]), set(sample_values.flatten())) + # In this test we're just interested in verifying there isn't a crash + # owing to mismatched types. 
b/30940152 + dist = bernoulli.Bernoulli(np.log([.2, .4])) + self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list()) @test_util.run_in_graph_and_eager_modes def testNotReparameterized(self): @@ -284,7 +265,7 @@ class BernoulliTest(test.TestCase): self.assertIsNone(grad_p) def testSampleActsLikeSampleN(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = [0.2, 0.6] dist = bernoulli.Bernoulli(probs=p) n = 1000 @@ -299,27 +280,24 @@ class BernoulliTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testMean(self): - with self.test_session(): - p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32) - dist = bernoulli.Bernoulli(probs=p) - self.assertAllEqual(self.evaluate(dist.mean()), p) + p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32) + dist = bernoulli.Bernoulli(probs=p) + self.assertAllEqual(self.evaluate(dist.mean()), p) @test_util.run_in_graph_and_eager_modes def testVarianceAndStd(self): var = lambda p: p * (1. 
- p) - with self.test_session(): - p = [[0.2, 0.7], [0.5, 0.4]] - dist = bernoulli.Bernoulli(probs=p) - self.assertAllClose( - self.evaluate(dist.variance()), - np.array( - [[var(0.2), var(0.7)], [var(0.5), var(0.4)]], dtype=np.float32)) - self.assertAllClose( - self.evaluate(dist.stddev()), - np.array( - [[np.sqrt(var(0.2)), np.sqrt(var(0.7))], - [np.sqrt(var(0.5)), np.sqrt(var(0.4))]], - dtype=np.float32)) + p = [[0.2, 0.7], [0.5, 0.4]] + dist = bernoulli.Bernoulli(probs=p) + self.assertAllClose( + self.evaluate(dist.variance()), + np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]], + dtype=np.float32)) + self.assertAllClose( + self.evaluate(dist.stddev()), + np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))], + [np.sqrt(var(0.5)), np.sqrt(var(0.4))]], + dtype=np.float32)) @test_util.run_in_graph_and_eager_modes def testBernoulliBernoulliKL(self): diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py index 36f3ffc333..d580a415dd 100644 --- a/tensorflow/python/kernel_tests/distributions/beta_test.py +++ b/tensorflow/python/kernel_tests/distributions/beta_test.py @@ -20,7 +20,6 @@ import importlib import numpy as np -from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import random_seed @@ -51,237 +50,215 @@ stats = try_import("scipy.stats") class BetaTest(test.TestCase): def testSimpleShapes(self): - with self.test_session(): - a = np.random.rand(3) - b = np.random.rand(3) - dist = beta_lib.Beta(a, b) - self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) - self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor())) - self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) - self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape) + a = np.random.rand(3) + b = np.random.rand(3) + dist = beta_lib.Beta(a, b) + 
self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor())) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape) def testComplexShapes(self): - with self.test_session(): - a = np.random.rand(3, 2, 2) - b = np.random.rand(3, 2, 2) - dist = beta_lib.Beta(a, b) - self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) - self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) - self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) - self.assertEqual( - tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) + a = np.random.rand(3, 2, 2) + b = np.random.rand(3, 2, 2) + dist = beta_lib.Beta(a, b) + self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) def testComplexShapesBroadcast(self): - with self.test_session(): - a = np.random.rand(3, 2, 2) - b = np.random.rand(2, 2) - dist = beta_lib.Beta(a, b) - self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) - self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) - self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) - self.assertEqual( - tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) + a = np.random.rand(3, 2, 2) + b = np.random.rand(2, 2) + dist = beta_lib.Beta(a, b) + self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) def testAlphaProperty(self): a = [[1., 2, 3]] b = [[2., 4, 3]] - with self.test_session(): - dist = 
beta_lib.Beta(a, b) - self.assertEqual([1, 3], dist.concentration1.get_shape()) - self.assertAllClose(a, self.evaluate(dist.concentration1)) + dist = beta_lib.Beta(a, b) + self.assertEqual([1, 3], dist.concentration1.get_shape()) + self.assertAllClose(a, self.evaluate(dist.concentration1)) def testBetaProperty(self): a = [[1., 2, 3]] b = [[2., 4, 3]] - with self.test_session(): - dist = beta_lib.Beta(a, b) - self.assertEqual([1, 3], dist.concentration0.get_shape()) - self.assertAllClose(b, self.evaluate(dist.concentration0)) + dist = beta_lib.Beta(a, b) + self.assertEqual([1, 3], dist.concentration0.get_shape()) + self.assertAllClose(b, self.evaluate(dist.concentration0)) def testPdfXProper(self): a = [[1., 2, 3]] b = [[2., 4, 3]] - with self.test_session(): - dist = beta_lib.Beta(a, b, validate_args=True) - self.evaluate(dist.prob([.1, .3, .6])) - self.evaluate(dist.prob([.2, .3, .5])) - # Either condition can trigger. - with self.assertRaisesOpError("sample must be positive"): - self.evaluate(dist.prob([-1., 0.1, 0.5])) - with self.assertRaisesOpError("sample must be positive"): - self.evaluate(dist.prob([0., 0.1, 0.5])) - with self.assertRaisesOpError("sample must be less than `1`"): - self.evaluate(dist.prob([.1, .2, 1.2])) - with self.assertRaisesOpError("sample must be less than `1`"): - self.evaluate(dist.prob([.1, .2, 1.0])) + dist = beta_lib.Beta(a, b, validate_args=True) + self.evaluate(dist.prob([.1, .3, .6])) + self.evaluate(dist.prob([.2, .3, .5])) + # Either condition can trigger. 
+ with self.assertRaisesOpError("sample must be positive"): + self.evaluate(dist.prob([-1., 0.1, 0.5])) + with self.assertRaisesOpError("sample must be positive"): + self.evaluate(dist.prob([0., 0.1, 0.5])) + with self.assertRaisesOpError("sample must be less than `1`"): + self.evaluate(dist.prob([.1, .2, 1.2])) + with self.assertRaisesOpError("sample must be less than `1`"): + self.evaluate(dist.prob([.1, .2, 1.0])) def testPdfTwoBatches(self): - with self.test_session(): - a = [1., 2] - b = [1., 2] - x = [.5, .5] - dist = beta_lib.Beta(a, b) - pdf = dist.prob(x) - self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) - self.assertEqual((2,), pdf.get_shape()) + a = [1., 2] + b = [1., 2] + x = [.5, .5] + dist = beta_lib.Beta(a, b) + pdf = dist.prob(x) + self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) + self.assertEqual((2,), pdf.get_shape()) def testPdfTwoBatchesNontrivialX(self): - with self.test_session(): - a = [1., 2] - b = [1., 2] - x = [.3, .7] - dist = beta_lib.Beta(a, b) - pdf = dist.prob(x) - self.assertAllClose([1, 63. / 50], self.evaluate(pdf)) - self.assertEqual((2,), pdf.get_shape()) + a = [1., 2] + b = [1., 2] + x = [.3, .7] + dist = beta_lib.Beta(a, b) + pdf = dist.prob(x) + self.assertAllClose([1, 63. / 50], self.evaluate(pdf)) + self.assertEqual((2,), pdf.get_shape()) def testPdfUniformZeroBatch(self): - with self.test_session(): - # This is equivalent to a uniform distribution - a = 1. - b = 1. - x = np.array([.1, .2, .3, .5, .8], dtype=np.float32) - dist = beta_lib.Beta(a, b) - pdf = dist.prob(x) - self.assertAllClose([1.] * 5, self.evaluate(pdf)) - self.assertEqual((5,), pdf.get_shape()) + # This is equivalent to a uniform distribution + a = 1. + b = 1. + x = np.array([.1, .2, .3, .5, .8], dtype=np.float32) + dist = beta_lib.Beta(a, b) + pdf = dist.prob(x) + self.assertAllClose([1.] 
* 5, self.evaluate(pdf)) + self.assertEqual((5,), pdf.get_shape()) def testPdfAlphaStretchedInBroadcastWhenSameRank(self): - with self.test_session(): - a = [[1., 2]] - b = [[1., 2]] - x = [[.5, .5], [.3, .7]] - dist = beta_lib.Beta(a, b) - pdf = dist.prob(x) - self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf)) - self.assertEqual((2, 2), pdf.get_shape()) + a = [[1., 2]] + b = [[1., 2]] + x = [[.5, .5], [.3, .7]] + dist = beta_lib.Beta(a, b) + pdf = dist.prob(x) + self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf)) + self.assertEqual((2, 2), pdf.get_shape()) def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): - with self.test_session(): - a = [1., 2] - b = [1., 2] - x = [[.5, .5], [.2, .8]] - pdf = beta_lib.Beta(a, b).prob(x) - self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf)) - self.assertEqual((2, 2), pdf.get_shape()) + a = [1., 2] + b = [1., 2] + x = [[.5, .5], [.2, .8]] + pdf = beta_lib.Beta(a, b).prob(x) + self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf)) + self.assertEqual((2, 2), pdf.get_shape()) def testPdfXStretchedInBroadcastWhenSameRank(self): - with self.test_session(): - a = [[1., 2], [2., 3]] - b = [[1., 2], [2., 3]] - x = [[.5, .5]] - pdf = beta_lib.Beta(a, b).prob(x) - self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf)) - self.assertEqual((2, 2), pdf.get_shape()) + a = [[1., 2], [2., 3]] + b = [[1., 2], [2., 3]] + x = [[.5, .5]] + pdf = beta_lib.Beta(a, b).prob(x) + self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf)) + self.assertEqual((2, 2), pdf.get_shape()) def testPdfXStretchedInBroadcastWhenLowerRank(self): - with self.test_session(): - a = [[1., 2], [2., 3]] - b = [[1., 2], [2., 3]] - x = [.5, .5] - pdf = beta_lib.Beta(a, b).prob(x) - self.assertAllClose([[1., 3. / 2], [3. / 2, 15. 
/ 8]], self.evaluate(pdf)) - self.assertEqual((2, 2), pdf.get_shape()) + a = [[1., 2], [2., 3]] + b = [[1., 2], [2., 3]] + x = [.5, .5] + pdf = beta_lib.Beta(a, b).prob(x) + self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf)) + self.assertEqual((2, 2), pdf.get_shape()) def testBetaMean(self): - with session.Session(): - a = [1., 2, 3] - b = [2., 4, 1.2] - dist = beta_lib.Beta(a, b) - self.assertEqual(dist.mean().get_shape(), (3,)) - if not stats: - return - expected_mean = stats.beta.mean(a, b) - self.assertAllClose(expected_mean, self.evaluate(dist.mean())) + a = [1., 2, 3] + b = [2., 4, 1.2] + dist = beta_lib.Beta(a, b) + self.assertEqual(dist.mean().get_shape(), (3,)) + if not stats: + return + expected_mean = stats.beta.mean(a, b) + self.assertAllClose(expected_mean, self.evaluate(dist.mean())) def testBetaVariance(self): - with session.Session(): - a = [1., 2, 3] - b = [2., 4, 1.2] - dist = beta_lib.Beta(a, b) - self.assertEqual(dist.variance().get_shape(), (3,)) - if not stats: - return - expected_variance = stats.beta.var(a, b) - self.assertAllClose(expected_variance, self.evaluate(dist.variance())) + a = [1., 2, 3] + b = [2., 4, 1.2] + dist = beta_lib.Beta(a, b) + self.assertEqual(dist.variance().get_shape(), (3,)) + if not stats: + return + expected_variance = stats.beta.var(a, b) + self.assertAllClose(expected_variance, self.evaluate(dist.variance())) def testBetaMode(self): - with session.Session(): - a = np.array([1.1, 2, 3]) - b = np.array([2., 4, 1.2]) - expected_mode = (a - 1) / (a + b - 2) - dist = beta_lib.Beta(a, b) - self.assertEqual(dist.mode().get_shape(), (3,)) - self.assertAllClose(expected_mode, self.evaluate(dist.mode())) + a = np.array([1.1, 2, 3]) + b = np.array([2., 4, 1.2]) + expected_mode = (a - 1) / (a + b - 2) + dist = beta_lib.Beta(a, b) + self.assertEqual(dist.mode().get_shape(), (3,)) + self.assertAllClose(expected_mode, self.evaluate(dist.mode())) def testBetaModeInvalid(self): - with session.Session(): - a 
= np.array([1., 2, 3]) - b = np.array([2., 4, 1.2]) - dist = beta_lib.Beta(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): - self.evaluate(dist.mode()) - - a = np.array([2., 2, 3]) - b = np.array([1., 4, 1.2]) - dist = beta_lib.Beta(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): - self.evaluate(dist.mode()) + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = beta_lib.Beta(a, b, allow_nan_stats=False) + with self.assertRaisesOpError("Condition x < y.*"): + self.evaluate(dist.mode()) + + a = np.array([2., 2, 3]) + b = np.array([1., 4, 1.2]) + dist = beta_lib.Beta(a, b, allow_nan_stats=False) + with self.assertRaisesOpError("Condition x < y.*"): + self.evaluate(dist.mode()) def testBetaModeEnableAllowNanStats(self): - with session.Session(): - a = np.array([1., 2, 3]) - b = np.array([2., 4, 1.2]) - dist = beta_lib.Beta(a, b, allow_nan_stats=True) + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = beta_lib.Beta(a, b, allow_nan_stats=True) - expected_mode = (a - 1) / (a + b - 2) - expected_mode[0] = np.nan - self.assertEqual((3,), dist.mode().get_shape()) - self.assertAllClose(expected_mode, self.evaluate(dist.mode())) + expected_mode = (a - 1) / (a + b - 2) + expected_mode[0] = np.nan + self.assertEqual((3,), dist.mode().get_shape()) + self.assertAllClose(expected_mode, self.evaluate(dist.mode())) - a = np.array([2., 2, 3]) - b = np.array([1., 4, 1.2]) - dist = beta_lib.Beta(a, b, allow_nan_stats=True) + a = np.array([2., 2, 3]) + b = np.array([1., 4, 1.2]) + dist = beta_lib.Beta(a, b, allow_nan_stats=True) - expected_mode = (a - 1) / (a + b - 2) - expected_mode[0] = np.nan - self.assertEqual((3,), dist.mode().get_shape()) - self.assertAllClose(expected_mode, self.evaluate(dist.mode())) + expected_mode = (a - 1) / (a + b - 2) + expected_mode[0] = np.nan + self.assertEqual((3,), dist.mode().get_shape()) + self.assertAllClose(expected_mode, self.evaluate(dist.mode())) def 
testBetaEntropy(self): - with session.Session(): - a = [1., 2, 3] - b = [2., 4, 1.2] - dist = beta_lib.Beta(a, b) - self.assertEqual(dist.entropy().get_shape(), (3,)) - if not stats: - return - expected_entropy = stats.beta.entropy(a, b) - self.assertAllClose(expected_entropy, self.evaluate(dist.entropy())) + a = [1., 2, 3] + b = [2., 4, 1.2] + dist = beta_lib.Beta(a, b) + self.assertEqual(dist.entropy().get_shape(), (3,)) + if not stats: + return + expected_entropy = stats.beta.entropy(a, b) + self.assertAllClose(expected_entropy, self.evaluate(dist.entropy())) def testBetaSample(self): - with self.test_session(): - a = 1. - b = 2. - beta = beta_lib.Beta(a, b) - n = constant_op.constant(100000) - samples = beta.sample(n) - sample_values = self.evaluate(samples) - self.assertEqual(sample_values.shape, (100000,)) - self.assertFalse(np.any(sample_values < 0.0)) - if not stats: - return - self.assertLess( - stats.kstest( - # Beta is a univariate distribution. - sample_values, - stats.beta(a=1., b=2.).cdf)[0], - 0.01) - # The standard error of the sample mean is 1 / (sqrt(18 * n)) - self.assertAllClose( - sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2) - self.assertAllClose( - np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1) + a = 1. + b = 2. + beta = beta_lib.Beta(a, b) + n = constant_op.constant(100000) + samples = beta.sample(n) + sample_values = self.evaluate(samples) + self.assertEqual(sample_values.shape, (100000,)) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + self.assertLess( + stats.kstest( + # Beta is a univariate distribution. 
+ sample_values, + stats.beta(a=1., b=2.).cdf)[0], + 0.01) + # The standard error of the sample mean is 1 / (sqrt(18 * n)) + self.assertAllClose( + sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2) + self.assertAllClose( + np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1) def testBetaFullyReparameterized(self): a = constant_op.constant(1.0) @@ -297,78 +274,71 @@ class BetaTest(test.TestCase): # Test that sampling with the same seed twice gives the same results. def testBetaSampleMultipleTimes(self): - with self.test_session(): - a_val = 1. - b_val = 2. - n_val = 100 + a_val = 1. + b_val = 2. + n_val = 100 - random_seed.set_random_seed(654321) - beta1 = beta_lib.Beta(concentration1=a_val, - concentration0=b_val, - name="beta1") - samples1 = self.evaluate(beta1.sample(n_val, seed=123456)) + random_seed.set_random_seed(654321) + beta1 = beta_lib.Beta( + concentration1=a_val, concentration0=b_val, name="beta1") + samples1 = self.evaluate(beta1.sample(n_val, seed=123456)) - random_seed.set_random_seed(654321) - beta2 = beta_lib.Beta(concentration1=a_val, - concentration0=b_val, - name="beta2") - samples2 = self.evaluate(beta2.sample(n_val, seed=123456)) + random_seed.set_random_seed(654321) + beta2 = beta_lib.Beta( + concentration1=a_val, concentration0=b_val, name="beta2") + samples2 = self.evaluate(beta2.sample(n_val, seed=123456)) - self.assertAllClose(samples1, samples2) + self.assertAllClose(samples1, samples2) def testBetaSampleMultidimensional(self): - with self.test_session(): - a = np.random.rand(3, 2, 2).astype(np.float32) - b = np.random.rand(3, 2, 2).astype(np.float32) - beta = beta_lib.Beta(a, b) - n = constant_op.constant(100000) - samples = beta.sample(n) - sample_values = self.evaluate(samples) - self.assertEqual(sample_values.shape, (100000, 3, 2, 2)) - self.assertFalse(np.any(sample_values < 0.0)) - if not stats: - return - self.assertAllClose( - sample_values[:, 1, :].mean(axis=0), - stats.beta.mean(a, b)[1, :], - 
atol=1e-1) + a = np.random.rand(3, 2, 2).astype(np.float32) + b = np.random.rand(3, 2, 2).astype(np.float32) + beta = beta_lib.Beta(a, b) + n = constant_op.constant(100000) + samples = beta.sample(n) + sample_values = self.evaluate(samples) + self.assertEqual(sample_values.shape, (100000, 3, 2, 2)) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + self.assertAllClose( + sample_values[:, 1, :].mean(axis=0), + stats.beta.mean(a, b)[1, :], + atol=1e-1) def testBetaCdf(self): - with self.test_session(): - shape = (30, 40, 50) - for dt in (np.float32, np.float64): - a = 10. * np.random.random(shape).astype(dt) - b = 10. * np.random.random(shape).astype(dt) - x = np.random.random(shape).astype(dt) - actual = self.evaluate(beta_lib.Beta(a, b).cdf(x)) - self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) - self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) - if not stats: - return - self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) + shape = (30, 40, 50) + for dt in (np.float32, np.float64): + a = 10. * np.random.random(shape).astype(dt) + b = 10. * np.random.random(shape).astype(dt) + x = np.random.random(shape).astype(dt) + actual = self.evaluate(beta_lib.Beta(a, b).cdf(x)) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) + if not stats: + return + self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) def testBetaLogCdf(self): - with self.test_session(): - shape = (30, 40, 50) - for dt in (np.float32, np.float64): - a = 10. * np.random.random(shape).astype(dt) - b = 10. * np.random.random(shape).astype(dt) - x = np.random.random(shape).astype(dt) - actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x))) - self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) - self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. 
>= x) - if not stats: - return - self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) + shape = (30, 40, 50) + for dt in (np.float32, np.float64): + a = 10. * np.random.random(shape).astype(dt) + b = 10. * np.random.random(shape).astype(dt) + x = np.random.random(shape).astype(dt) + actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x))) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) + if not stats: + return + self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) def testBetaWithSoftplusConcentration(self): - with self.test_session(): - a, b = -4.2, -9.1 - dist = beta_lib.BetaWithSoftplusConcentration(a, b) - self.assertAllClose( - self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1)) - self.assertAllClose( - self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0)) + a, b = -4.2, -9.1 + dist = beta_lib.BetaWithSoftplusConcentration(a, b) + self.assertAllClose( + self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1)) + self.assertAllClose( + self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0)) def testBetaBetaKL(self): for shape in [(10,), (4, 5)]: diff --git a/tensorflow/python/kernel_tests/distributions/bijector_test.py b/tensorflow/python/kernel_tests/distributions/bijector_test.py index 8b11556330..e20f59f48a 100644 --- a/tensorflow/python/kernel_tests/distributions/bijector_test.py +++ b/tensorflow/python/kernel_tests/distributions/bijector_test.py @@ -36,11 +36,10 @@ class BaseBijectorTest(test.TestCase): """Tests properties of the Bijector base-class.""" def testIsAbstract(self): - with self.test_session(): - with self.assertRaisesRegexp(TypeError, - ("Can't instantiate abstract class Bijector " - "with abstract methods __init__")): - bijector.Bijector() # pylint: disable=abstract-class-instantiated + with self.assertRaisesRegexp(TypeError, + ("Can't instantiate abstract 
class Bijector " + "with abstract methods __init__")): + bijector.Bijector() # pylint: disable=abstract-class-instantiated def testDefaults(self): class _BareBonesBijector(bijector.Bijector): @@ -136,7 +135,7 @@ class BijectorTestEventNdims(test.TestCase): def testBijectorDynamicEventNdims(self): bij = BrokenBijector(validate_args=True) event_ndims = array_ops.placeholder(dtype=np.int32, shape=None) - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Expected scalar"): bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({ event_ndims: (1, 2)}) @@ -308,7 +307,7 @@ class BijectorReduceEventDimsTest(test.TestCase): event_ndims = array_ops.placeholder(dtype=np.int32, shape=[]) bij = ExpOnlyJacobian(forward_min_event_ndims=1) bij.inverse_log_det_jacobian(x, event_ndims=event_ndims) - with self.test_session() as sess: + with self.cached_session() as sess: ildj = sess.run(bij.inverse_log_det_jacobian(x, event_ndims=event_ndims), feed_dict={event_ndims: 1}) self.assertAllClose(-np.log(x_), ildj) diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py index 67ed0447ed..cace5b3ba2 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py @@ -49,115 +49,102 @@ stats = try_import("scipy.stats") class DirichletTest(test.TestCase): def testSimpleShapes(self): - with self.test_session(): - alpha = np.random.rand(3) - dist = dirichlet_lib.Dirichlet(alpha) - self.assertEqual(3, self.evaluate(dist.event_shape_tensor())) - self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor())) - self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape) - self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape) + alpha = np.random.rand(3) + dist = dirichlet_lib.Dirichlet(alpha) + self.assertEqual(3, self.evaluate(dist.event_shape_tensor())) + 
self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor())) + self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape) def testComplexShapes(self): - with self.test_session(): - alpha = np.random.rand(3, 2, 2) - dist = dirichlet_lib.Dirichlet(alpha) - self.assertEqual(2, self.evaluate(dist.event_shape_tensor())) - self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor())) - self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape) - self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape) + alpha = np.random.rand(3, 2, 2) + dist = dirichlet_lib.Dirichlet(alpha) + self.assertEqual(2, self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor())) + self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape) def testConcentrationProperty(self): alpha = [[1., 2, 3]] - with self.test_session(): - dist = dirichlet_lib.Dirichlet(alpha) - self.assertEqual([1, 3], dist.concentration.get_shape()) - self.assertAllClose(alpha, self.evaluate(dist.concentration)) + dist = dirichlet_lib.Dirichlet(alpha) + self.assertEqual([1, 3], dist.concentration.get_shape()) + self.assertAllClose(alpha, self.evaluate(dist.concentration)) def testPdfXProper(self): alpha = [[1., 2, 3]] - with self.test_session(): - dist = dirichlet_lib.Dirichlet(alpha, validate_args=True) - self.evaluate(dist.prob([.1, .3, .6])) - self.evaluate(dist.prob([.2, .3, .5])) - # Either condition can trigger. 
- with self.assertRaisesOpError("samples must be positive"): - self.evaluate(dist.prob([-1., 1.5, 0.5])) - with self.assertRaisesOpError("samples must be positive"): - self.evaluate(dist.prob([0., .1, .9])) - with self.assertRaisesOpError( - "sample last-dimension must sum to `1`"): - self.evaluate(dist.prob([.1, .2, .8])) + dist = dirichlet_lib.Dirichlet(alpha, validate_args=True) + self.evaluate(dist.prob([.1, .3, .6])) + self.evaluate(dist.prob([.2, .3, .5])) + # Either condition can trigger. + with self.assertRaisesOpError("samples must be positive"): + self.evaluate(dist.prob([-1., 1.5, 0.5])) + with self.assertRaisesOpError("samples must be positive"): + self.evaluate(dist.prob([0., .1, .9])) + with self.assertRaisesOpError("sample last-dimension must sum to `1`"): + self.evaluate(dist.prob([.1, .2, .8])) def testPdfZeroBatches(self): - with self.test_session(): - alpha = [1., 2] - x = [.5, .5] - dist = dirichlet_lib.Dirichlet(alpha) - pdf = dist.prob(x) - self.assertAllClose(1., self.evaluate(pdf)) - self.assertEqual((), pdf.get_shape()) + alpha = [1., 2] + x = [.5, .5] + dist = dirichlet_lib.Dirichlet(alpha) + pdf = dist.prob(x) + self.assertAllClose(1., self.evaluate(pdf)) + self.assertEqual((), pdf.get_shape()) def testPdfZeroBatchesNontrivialX(self): - with self.test_session(): - alpha = [1., 2] - x = [.3, .7] - dist = dirichlet_lib.Dirichlet(alpha) - pdf = dist.prob(x) - self.assertAllClose(7. / 5, self.evaluate(pdf)) - self.assertEqual((), pdf.get_shape()) + alpha = [1., 2] + x = [.3, .7] + dist = dirichlet_lib.Dirichlet(alpha) + pdf = dist.prob(x) + self.assertAllClose(7. 
/ 5, self.evaluate(pdf)) + self.assertEqual((), pdf.get_shape()) def testPdfUniformZeroBatches(self): - with self.test_session(): - # Corresponds to a uniform distribution - alpha = [1., 1, 1] - x = [[.2, .5, .3], [.3, .4, .3]] - dist = dirichlet_lib.Dirichlet(alpha) - pdf = dist.prob(x) - self.assertAllClose([2., 2.], self.evaluate(pdf)) - self.assertEqual((2), pdf.get_shape()) + # Corresponds to a uniform distribution + alpha = [1., 1, 1] + x = [[.2, .5, .3], [.3, .4, .3]] + dist = dirichlet_lib.Dirichlet(alpha) + pdf = dist.prob(x) + self.assertAllClose([2., 2.], self.evaluate(pdf)) + self.assertEqual((2), pdf.get_shape()) def testPdfAlphaStretchedInBroadcastWhenSameRank(self): - with self.test_session(): - alpha = [[1., 2]] - x = [[.5, .5], [.3, .7]] - dist = dirichlet_lib.Dirichlet(alpha) - pdf = dist.prob(x) - self.assertAllClose([1., 7. / 5], self.evaluate(pdf)) - self.assertEqual((2), pdf.get_shape()) + alpha = [[1., 2]] + x = [[.5, .5], [.3, .7]] + dist = dirichlet_lib.Dirichlet(alpha) + pdf = dist.prob(x) + self.assertAllClose([1., 7. / 5], self.evaluate(pdf)) + self.assertEqual((2), pdf.get_shape()) def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): - with self.test_session(): - alpha = [1., 2] - x = [[.5, .5], [.2, .8]] - pdf = dirichlet_lib.Dirichlet(alpha).prob(x) - self.assertAllClose([1., 8. / 5], self.evaluate(pdf)) - self.assertEqual((2), pdf.get_shape()) + alpha = [1., 2] + x = [[.5, .5], [.2, .8]] + pdf = dirichlet_lib.Dirichlet(alpha).prob(x) + self.assertAllClose([1., 8. / 5], self.evaluate(pdf)) + self.assertEqual((2), pdf.get_shape()) def testPdfXStretchedInBroadcastWhenSameRank(self): - with self.test_session(): - alpha = [[1., 2], [2., 3]] - x = [[.5, .5]] - pdf = dirichlet_lib.Dirichlet(alpha).prob(x) - self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) - self.assertEqual((2), pdf.get_shape()) + alpha = [[1., 2], [2., 3]] + x = [[.5, .5]] + pdf = dirichlet_lib.Dirichlet(alpha).prob(x) + self.assertAllClose([1., 3. 
/ 2], self.evaluate(pdf)) + self.assertEqual((2), pdf.get_shape()) def testPdfXStretchedInBroadcastWhenLowerRank(self): - with self.test_session(): - alpha = [[1., 2], [2., 3]] - x = [.5, .5] - pdf = dirichlet_lib.Dirichlet(alpha).prob(x) - self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) - self.assertEqual((2), pdf.get_shape()) + alpha = [[1., 2], [2., 3]] + x = [.5, .5] + pdf = dirichlet_lib.Dirichlet(alpha).prob(x) + self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) + self.assertEqual((2), pdf.get_shape()) def testMean(self): - with self.test_session(): - alpha = [1., 2, 3] - dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) - self.assertEqual(dirichlet.mean().get_shape(), [3]) - if not stats: - return - expected_mean = stats.dirichlet.mean(alpha) - self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean) + alpha = [1., 2, 3] + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) + self.assertEqual(dirichlet.mean().get_shape(), [3]) + if not stats: + return + expected_mean = stats.dirichlet.mean(alpha) + self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean) def testCovarianceFromSampling(self): alpha = np.array([[1., 2, 3], @@ -197,73 +184,66 @@ class DirichletTest(test.TestCase): self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) 
def testVariance(self): - with self.test_session(): - alpha = [1., 2, 3] - denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1) - dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) - self.assertEqual(dirichlet.covariance().get_shape(), (3, 3)) - if not stats: - return - expected_covariance = np.diag(stats.dirichlet.var(alpha)) - expected_covariance += [[0., -2, -3], [-2, 0, -6], - [-3, -6, 0]] / denominator - self.assertAllClose( - self.evaluate(dirichlet.covariance()), expected_covariance) + alpha = [1., 2, 3] + denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1) + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) + self.assertEqual(dirichlet.covariance().get_shape(), (3, 3)) + if not stats: + return + expected_covariance = np.diag(stats.dirichlet.var(alpha)) + expected_covariance += [[0., -2, -3], [-2, 0, -6], [-3, -6, 0] + ] / denominator + self.assertAllClose( + self.evaluate(dirichlet.covariance()), expected_covariance) def testMode(self): - with self.test_session(): - alpha = np.array([1.1, 2, 3]) - expected_mode = (alpha - 1) / (np.sum(alpha) - 3) - dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) - self.assertEqual(dirichlet.mode().get_shape(), [3]) - self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) + alpha = np.array([1.1, 2, 3]) + expected_mode = (alpha - 1) / (np.sum(alpha) - 3) + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) + self.assertEqual(dirichlet.mode().get_shape(), [3]) + self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) def testModeInvalid(self): - with self.test_session(): - alpha = np.array([1., 2, 3]) - dirichlet = dirichlet_lib.Dirichlet(concentration=alpha, - allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): - self.evaluate(dirichlet.mode()) + alpha = np.array([1., 2, 3]) + dirichlet = dirichlet_lib.Dirichlet( + concentration=alpha, allow_nan_stats=False) + with self.assertRaisesOpError("Condition x < y.*"): + self.evaluate(dirichlet.mode()) 
def testModeEnableAllowNanStats(self): - with self.test_session(): - alpha = np.array([1., 2, 3]) - dirichlet = dirichlet_lib.Dirichlet(concentration=alpha, - allow_nan_stats=True) - expected_mode = np.zeros_like(alpha) + np.nan + alpha = np.array([1., 2, 3]) + dirichlet = dirichlet_lib.Dirichlet( + concentration=alpha, allow_nan_stats=True) + expected_mode = np.zeros_like(alpha) + np.nan - self.assertEqual(dirichlet.mode().get_shape(), [3]) - self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) + self.assertEqual(dirichlet.mode().get_shape(), [3]) + self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) def testEntropy(self): - with self.test_session(): - alpha = [1., 2, 3] - dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) - self.assertEqual(dirichlet.entropy().get_shape(), ()) - if not stats: - return - expected_entropy = stats.dirichlet.entropy(alpha) - self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy) + alpha = [1., 2, 3] + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) + self.assertEqual(dirichlet.entropy().get_shape(), ()) + if not stats: + return + expected_entropy = stats.dirichlet.entropy(alpha) + self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy) def testSample(self): - with self.test_session(): - alpha = [1., 2] - dirichlet = dirichlet_lib.Dirichlet(alpha) - n = constant_op.constant(100000) - samples = dirichlet.sample(n) - sample_values = self.evaluate(samples) - self.assertEqual(sample_values.shape, (100000, 2)) - self.assertTrue(np.all(sample_values > 0.0)) - if not stats: - return - self.assertLess( - stats.kstest( - # Beta is a univariate distribution. 
- sample_values[:, 0], - stats.beta( - a=1., b=2.).cdf)[0], - 0.01) + alpha = [1., 2] + dirichlet = dirichlet_lib.Dirichlet(alpha) + n = constant_op.constant(100000) + samples = dirichlet.sample(n) + sample_values = self.evaluate(samples) + self.assertEqual(sample_values.shape, (100000, 2)) + self.assertTrue(np.all(sample_values > 0.0)) + if not stats: + return + self.assertLess( + stats.kstest( + # Beta is a univariate distribution. + sample_values[:, 0], + stats.beta(a=1., b=2.).cdf)[0], + 0.01) def testDirichletFullyReparameterized(self): alpha = constant_op.constant([1.0, 2.0, 3.0]) diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py index 850da3e969..27d1291912 100644 --- a/tensorflow/python/kernel_tests/distributions/exponential_test.py +++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py @@ -22,7 +22,6 @@ import importlib import numpy as np -from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util @@ -48,121 +47,108 @@ stats = try_import("scipy.stats") class ExponentialTest(test.TestCase): def testExponentialLogPDF(self): - with session.Session(): - batch_size = 6 - lam = constant_op.constant([2.0] * batch_size) - lam_v = 2.0 - x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - exponential = exponential_lib.Exponential(rate=lam) + batch_size = 6 + lam = constant_op.constant([2.0] * batch_size) + lam_v = 2.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + exponential = exponential_lib.Exponential(rate=lam) - log_pdf = exponential.log_prob(x) - self.assertEqual(log_pdf.get_shape(), (6,)) + log_pdf = exponential.log_prob(x) + self.assertEqual(log_pdf.get_shape(), (6,)) - pdf = exponential.prob(x) - self.assertEqual(pdf.get_shape(), (6,)) + pdf = exponential.prob(x) + 
self.assertEqual(pdf.get_shape(), (6,)) - if not stats: - return - expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v) - self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) - self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) + if not stats: + return + expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v) + self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) + self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) def testExponentialCDF(self): - with session.Session(): - batch_size = 6 - lam = constant_op.constant([2.0] * batch_size) - lam_v = 2.0 - x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + batch_size = 6 + lam = constant_op.constant([2.0] * batch_size) + lam_v = 2.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - exponential = exponential_lib.Exponential(rate=lam) + exponential = exponential_lib.Exponential(rate=lam) - cdf = exponential.cdf(x) - self.assertEqual(cdf.get_shape(), (6,)) + cdf = exponential.cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) - if not stats: - return - expected_cdf = stats.expon.cdf(x, scale=1 / lam_v) - self.assertAllClose(self.evaluate(cdf), expected_cdf) + if not stats: + return + expected_cdf = stats.expon.cdf(x, scale=1 / lam_v) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testExponentialMean(self): - with session.Session(): - lam_v = np.array([1.0, 4.0, 2.5]) - exponential = exponential_lib.Exponential(rate=lam_v) - self.assertEqual(exponential.mean().get_shape(), (3,)) - if not stats: - return - expected_mean = stats.expon.mean(scale=1 / lam_v) - self.assertAllClose(self.evaluate(exponential.mean()), expected_mean) + lam_v = np.array([1.0, 4.0, 2.5]) + exponential = exponential_lib.Exponential(rate=lam_v) + self.assertEqual(exponential.mean().get_shape(), (3,)) + if not stats: + return + expected_mean = stats.expon.mean(scale=1 / lam_v) + self.assertAllClose(self.evaluate(exponential.mean()), expected_mean) def 
testExponentialVariance(self): - with session.Session(): - lam_v = np.array([1.0, 4.0, 2.5]) - exponential = exponential_lib.Exponential(rate=lam_v) - self.assertEqual(exponential.variance().get_shape(), (3,)) - if not stats: - return - expected_variance = stats.expon.var(scale=1 / lam_v) - self.assertAllClose( - self.evaluate(exponential.variance()), expected_variance) + lam_v = np.array([1.0, 4.0, 2.5]) + exponential = exponential_lib.Exponential(rate=lam_v) + self.assertEqual(exponential.variance().get_shape(), (3,)) + if not stats: + return + expected_variance = stats.expon.var(scale=1 / lam_v) + self.assertAllClose( + self.evaluate(exponential.variance()), expected_variance) def testExponentialEntropy(self): - with session.Session(): - lam_v = np.array([1.0, 4.0, 2.5]) - exponential = exponential_lib.Exponential(rate=lam_v) - self.assertEqual(exponential.entropy().get_shape(), (3,)) - if not stats: - return - expected_entropy = stats.expon.entropy(scale=1 / lam_v) - self.assertAllClose( - self.evaluate(exponential.entropy()), expected_entropy) + lam_v = np.array([1.0, 4.0, 2.5]) + exponential = exponential_lib.Exponential(rate=lam_v) + self.assertEqual(exponential.entropy().get_shape(), (3,)) + if not stats: + return + expected_entropy = stats.expon.entropy(scale=1 / lam_v) + self.assertAllClose(self.evaluate(exponential.entropy()), expected_entropy) def testExponentialSample(self): - with self.test_session(): - lam = constant_op.constant([3.0, 4.0]) - lam_v = [3.0, 4.0] - n = constant_op.constant(100000) - exponential = exponential_lib.Exponential(rate=lam) - - samples = exponential.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(sample_values.shape, (100000, 2)) - self.assertFalse(np.any(sample_values < 0.0)) - if not stats: - return - for i in range(2): - self.assertLess( - stats.kstest( - sample_values[:, i], stats.expon(scale=1.0 / lam_v[i]).cdf)[0], - 0.01) + lam = constant_op.constant([3.0, 4.0]) + lam_v = [3.0, 4.0] + n 
= constant_op.constant(100000) + exponential = exponential_lib.Exponential(rate=lam) + + samples = exponential.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(sample_values.shape, (100000, 2)) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + for i in range(2): + self.assertLess( + stats.kstest(sample_values[:, i], + stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) def testExponentialSampleMultiDimensional(self): - with self.test_session(): - batch_size = 2 - lam_v = [3.0, 22.0] - lam = constant_op.constant([lam_v] * batch_size) + batch_size = 2 + lam_v = [3.0, 22.0] + lam = constant_op.constant([lam_v] * batch_size) - exponential = exponential_lib.Exponential(rate=lam) + exponential = exponential_lib.Exponential(rate=lam) + + n = 100000 + samples = exponential.sample(n, seed=138) + self.assertEqual(samples.get_shape(), (n, batch_size, 2)) + + sample_values = self.evaluate(samples) - n = 100000 - samples = exponential.sample(n, seed=138) - self.assertEqual(samples.get_shape(), (n, batch_size, 2)) - - sample_values = self.evaluate(samples) - - self.assertFalse(np.any(sample_values < 0.0)) - if not stats: - return - for i in range(2): - self.assertLess( - stats.kstest( - sample_values[:, 0, i], - stats.expon(scale=1.0 / lam_v[i]).cdf)[0], - 0.01) - self.assertLess( - stats.kstest( - sample_values[:, 1, i], - stats.expon(scale=1.0 / lam_v[i]).cdf)[0], - 0.01) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + for i in range(2): + self.assertLess( + stats.kstest(sample_values[:, 0, i], + stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) + self.assertLess( + stats.kstest(sample_values[:, 1, i], + stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) def testFullyReparameterized(self): lam = constant_op.constant([0.1, 1.0]) @@ -174,11 +160,10 @@ class ExponentialTest(test.TestCase): self.assertIsNotNone(grad_lam) def testExponentialWithSoftplusRate(self): - with self.test_session(): - lam = [-2.2, 
-3.4] - exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam) - self.assertAllClose( - self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate)) + lam = [-2.2, -3.4] + exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam) + self.assertAllClose( + self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate)) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py index 297e20264c..4eff40b029 100644 --- a/tensorflow/python/kernel_tests/distributions/gamma_test.py +++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py @@ -50,221 +50,203 @@ stats = try_import("scipy.stats") class GammaTest(test.TestCase): def testGammaShape(self): - with self.test_session(): - alpha = constant_op.constant([3.0] * 5) - beta = constant_op.constant(11.0) - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + alpha = constant_op.constant([3.0] * 5) + beta = constant_op.constant(11.0) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,)) - self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), []) - self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([])) + self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,)) + self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), []) + self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([])) def testGammaLogPDF(self): - with self.test_session(): - batch_size = 6 - alpha = constant_op.constant([2.0] * batch_size) - beta = constant_op.constant([3.0] * batch_size) - alpha_v = 2.0 - beta_v = 3.0 - x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - log_pdf = 
gamma.log_prob(x) - self.assertEqual(log_pdf.get_shape(), (6,)) - pdf = gamma.prob(x) - self.assertEqual(pdf.get_shape(), (6,)) - if not stats: - return - expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) - self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) + batch_size = 6 + alpha = constant_op.constant([2.0] * batch_size) + beta = constant_op.constant([3.0] * batch_size) + alpha_v = 2.0 + beta_v = 3.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + log_pdf = gamma.log_prob(x) + self.assertEqual(log_pdf.get_shape(), (6,)) + pdf = gamma.prob(x) + self.assertEqual(pdf.get_shape(), (6,)) + if not stats: + return + expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) + self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) + self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) def testGammaLogPDFMultidimensional(self): - with self.test_session(): - batch_size = 6 - alpha = constant_op.constant([[2.0, 4.0]] * batch_size) - beta = constant_op.constant([[3.0, 4.0]] * batch_size) - alpha_v = np.array([2.0, 4.0]) - beta_v = np.array([3.0, 4.0]) - x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - log_pdf = gamma.log_prob(x) - log_pdf_values = self.evaluate(log_pdf) - self.assertEqual(log_pdf.get_shape(), (6, 2)) - pdf = gamma.prob(x) - pdf_values = self.evaluate(pdf) - self.assertEqual(pdf.get_shape(), (6, 2)) - if not stats: - return - expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(log_pdf_values, expected_log_pdf) - self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) + batch_size = 6 + alpha = constant_op.constant([[2.0, 4.0]] * batch_size) + beta = constant_op.constant([[3.0, 4.0]] * batch_size) + alpha_v = np.array([2.0, 4.0]) + beta_v = 
np.array([3.0, 4.0]) + x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + log_pdf = gamma.log_prob(x) + log_pdf_values = self.evaluate(log_pdf) + self.assertEqual(log_pdf.get_shape(), (6, 2)) + pdf = gamma.prob(x) + pdf_values = self.evaluate(pdf) + self.assertEqual(pdf.get_shape(), (6, 2)) + if not stats: + return + expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) + self.assertAllClose(log_pdf_values, expected_log_pdf) + self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) def testGammaLogPDFMultidimensionalBroadcasting(self): - with self.test_session(): - batch_size = 6 - alpha = constant_op.constant([[2.0, 4.0]] * batch_size) - beta = constant_op.constant(3.0) - alpha_v = np.array([2.0, 4.0]) - beta_v = 3.0 - x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - log_pdf = gamma.log_prob(x) - log_pdf_values = self.evaluate(log_pdf) - self.assertEqual(log_pdf.get_shape(), (6, 2)) - pdf = gamma.prob(x) - pdf_values = self.evaluate(pdf) - self.assertEqual(pdf.get_shape(), (6, 2)) - - if not stats: - return - expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(log_pdf_values, expected_log_pdf) - self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) + batch_size = 6 + alpha = constant_op.constant([[2.0, 4.0]] * batch_size) + beta = constant_op.constant(3.0) + alpha_v = np.array([2.0, 4.0]) + beta_v = 3.0 + x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + log_pdf = gamma.log_prob(x) + log_pdf_values = self.evaluate(log_pdf) + self.assertEqual(log_pdf.get_shape(), (6, 2)) + pdf = gamma.prob(x) + pdf_values = self.evaluate(pdf) + self.assertEqual(pdf.get_shape(), (6, 2)) - def testGammaCDF(self): - with self.test_session(): - batch_size = 6 - alpha = constant_op.constant([2.0] * 
batch_size) - beta = constant_op.constant([3.0] * batch_size) - alpha_v = 2.0 - beta_v = 3.0 - x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + if not stats: + return + expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) + self.assertAllClose(log_pdf_values, expected_log_pdf) + self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - cdf = gamma.cdf(x) - self.assertEqual(cdf.get_shape(), (6,)) - if not stats: - return - expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(self.evaluate(cdf), expected_cdf) + def testGammaCDF(self): + batch_size = 6 + alpha = constant_op.constant([2.0] * batch_size) + beta = constant_op.constant([3.0] * batch_size) + alpha_v = 2.0 + beta_v = 3.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + cdf = gamma.cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) + if not stats: + return + expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testGammaMean(self): - with self.test_session(): - alpha_v = np.array([1.0, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) - self.assertEqual(gamma.mean().get_shape(), (3,)) - if not stats: - return - expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v) - self.assertAllClose(self.evaluate(gamma.mean()), expected_means) + alpha_v = np.array([1.0, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) + self.assertEqual(gamma.mean().get_shape(), (3,)) + if not stats: + return + expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v) + self.assertAllClose(self.evaluate(gamma.mean()), expected_means) def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): - with self.test_session(): - alpha_v = 
np.array([5.5, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) - expected_modes = (alpha_v - 1) / beta_v - self.assertEqual(gamma.mode().get_shape(), (3,)) - self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) + alpha_v = np.array([5.5, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) + expected_modes = (alpha_v - 1) / beta_v + self.assertEqual(gamma.mode().get_shape(), (3,)) + self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): - with self.test_session(): - # Mode will not be defined for the first entry. - alpha_v = np.array([0.5, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - allow_nan_stats=False) - with self.assertRaisesOpError("x < y"): - self.evaluate(gamma.mode()) + # Mode will not be defined for the first entry. + alpha_v = np.array([0.5, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma( + concentration=alpha_v, rate=beta_v, allow_nan_stats=False) + with self.assertRaisesOpError("x < y"): + self.evaluate(gamma.mode()) def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self): - with self.test_session(): - # Mode will not be defined for the first entry. - alpha_v = np.array([0.5, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - allow_nan_stats=True) - expected_modes = (alpha_v - 1) / beta_v - expected_modes[0] = np.nan - self.assertEqual(gamma.mode().get_shape(), (3,)) - self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) + # Mode will not be defined for the first entry. 
+ alpha_v = np.array([0.5, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma( + concentration=alpha_v, rate=beta_v, allow_nan_stats=True) + expected_modes = (alpha_v - 1) / beta_v + expected_modes[0] = np.nan + self.assertEqual(gamma.mode().get_shape(), (3,)) + self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) def testGammaVariance(self): - with self.test_session(): - alpha_v = np.array([1.0, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) - self.assertEqual(gamma.variance().get_shape(), (3,)) - if not stats: - return - expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v) - self.assertAllClose(self.evaluate(gamma.variance()), expected_variances) + alpha_v = np.array([1.0, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) + self.assertEqual(gamma.variance().get_shape(), (3,)) + if not stats: + return + expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v) + self.assertAllClose(self.evaluate(gamma.variance()), expected_variances) def testGammaStd(self): - with self.test_session(): - alpha_v = np.array([1.0, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) - self.assertEqual(gamma.stddev().get_shape(), (3,)) - if not stats: - return - expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v) - self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev) + alpha_v = np.array([1.0, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) + self.assertEqual(gamma.stddev().get_shape(), (3,)) + if not stats: + return + expected_stddev = stats.gamma.std(alpha_v, scale=1. 
/ beta_v) + self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev) def testGammaEntropy(self): - with self.test_session(): - alpha_v = np.array([1.0, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) - self.assertEqual(gamma.entropy().get_shape(), (3,)) - if not stats: - return - expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v) - self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy) + alpha_v = np.array([1.0, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) + self.assertEqual(gamma.entropy().get_shape(), (3,)) + if not stats: + return + expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v) + self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy) def testGammaSampleSmallAlpha(self): - with self.test_session(): - alpha_v = 0.05 - beta_v = 1.0 - alpha = constant_op.constant(alpha_v) - beta = constant_op.constant(beta_v) - n = 100000 - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - samples = gamma.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(samples.get_shape(), (n,)) - self.assertEqual(sample_values.shape, (n,)) - self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) - if not stats: - return - self.assertAllClose( - sample_values.mean(), - stats.gamma.mean( - alpha_v, scale=1 / beta_v), - atol=.01) - self.assertAllClose( - sample_values.var(), - stats.gamma.var(alpha_v, scale=1 / beta_v), - atol=.15) + alpha_v = 0.05 + beta_v = 1.0 + alpha = constant_op.constant(alpha_v) + beta = constant_op.constant(beta_v) + n = 100000 + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + samples = gamma.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(samples.get_shape(), (n,)) + self.assertEqual(sample_values.shape, (n,)) + self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) + if not stats: 
+ return + self.assertAllClose( + sample_values.mean(), + stats.gamma.mean(alpha_v, scale=1 / beta_v), + atol=.01) + self.assertAllClose( + sample_values.var(), + stats.gamma.var(alpha_v, scale=1 / beta_v), + atol=.15) def testGammaSample(self): - with self.test_session(): - alpha_v = 4.0 - beta_v = 3.0 - alpha = constant_op.constant(alpha_v) - beta = constant_op.constant(beta_v) - n = 100000 - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - samples = gamma.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(samples.get_shape(), (n,)) - self.assertEqual(sample_values.shape, (n,)) - self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) - if not stats: - return - self.assertAllClose( - sample_values.mean(), - stats.gamma.mean( - alpha_v, scale=1 / beta_v), - atol=.01) - self.assertAllClose( - sample_values.var(), - stats.gamma.var(alpha_v, scale=1 / beta_v), - atol=.15) + alpha_v = 4.0 + beta_v = 3.0 + alpha = constant_op.constant(alpha_v) + beta = constant_op.constant(beta_v) + n = 100000 + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + samples = gamma.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(samples.get_shape(), (n,)) + self.assertEqual(sample_values.shape, (n,)) + self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) + if not stats: + return + self.assertAllClose( + sample_values.mean(), + stats.gamma.mean(alpha_v, scale=1 / beta_v), + atol=.01) + self.assertAllClose( + sample_values.var(), + stats.gamma.var(alpha_v, scale=1 / beta_v), + atol=.15) def testGammaFullyReparameterized(self): alpha = constant_op.constant(4.0) @@ -279,37 +261,37 @@ class GammaTest(test.TestCase): self.assertIsNotNone(grad_beta) def testGammaSampleMultiDimensional(self): - with self.test_session(): - alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 - beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 - gamma = gamma_lib.Gamma(concentration=alpha_v, 
rate=beta_v) - n = 10000 - samples = gamma.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(samples.get_shape(), (n, 10, 100)) - self.assertEqual(sample_values.shape, (n, 10, 100)) - zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100 - alpha_bc = alpha_v + zeros - beta_bc = beta_v + zeros - if not stats: - return - self.assertAllClose( - sample_values.mean(axis=0), - stats.gamma.mean( - alpha_bc, scale=1 / beta_bc), - atol=0., rtol=.05) - self.assertAllClose( - sample_values.var(axis=0), - stats.gamma.var(alpha_bc, scale=1 / beta_bc), - atol=10.0, rtol=0.) - fails = 0 - trials = 0 - for ai, a in enumerate(np.reshape(alpha_v, [-1])): - for bi, b in enumerate(np.reshape(beta_v, [-1])): - s = sample_values[:, bi, ai] - trials += 1 - fails += 0 if self._kstest(a, b, s) else 1 - self.assertLess(fails, trials * 0.03) + alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 + beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) + n = 10000 + samples = gamma.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(samples.get_shape(), (n, 10, 100)) + self.assertEqual(sample_values.shape, (n, 10, 100)) + zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100 + alpha_bc = alpha_v + zeros + beta_bc = beta_v + zeros + if not stats: + return + self.assertAllClose( + sample_values.mean(axis=0), + stats.gamma.mean(alpha_bc, scale=1 / beta_bc), + atol=0., + rtol=.05) + self.assertAllClose( + sample_values.var(axis=0), + stats.gamma.var(alpha_bc, scale=1 / beta_bc), + atol=10.0, + rtol=0.) + fails = 0 + trials = 0 + for ai, a in enumerate(np.reshape(alpha_v, [-1])): + for bi, b in enumerate(np.reshape(beta_v, [-1])): + s = sample_values[:, bi, ai] + trials += 1 + fails += 0 if self._kstest(a, b, s) else 1 + self.assertLess(fails, trials * 0.03) def _kstest(self, alpha, beta, samples): # Uses the Kolmogorov-Smirnov test for goodness of fit. 
@@ -320,30 +302,29 @@ class GammaTest(test.TestCase): return ks < 0.02 def testGammaPdfOfSampleMultiDims(self): - with self.test_session(): - gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]]) - num = 50000 - samples = gamma.sample(num, seed=137) - pdfs = gamma.prob(samples) - sample_vals, pdf_vals = self.evaluate([samples, pdfs]) - self.assertEqual(samples.get_shape(), (num, 2, 2)) - self.assertEqual(pdfs.get_shape(), (num, 2, 2)) - self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) - self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02) - self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02) - self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02) - if not stats: - return - self.assertAllClose( - stats.gamma.mean( - [[7., 11.], [7., 11.]], scale=1 / np.array([[5., 5.], [6., 6.]])), - sample_vals.mean(axis=0), - atol=.1) - self.assertAllClose( - stats.gamma.var([[7., 11.], [7., 11.]], - scale=1 / np.array([[5., 5.], [6., 6.]])), - sample_vals.var(axis=0), - atol=.1) + gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]]) + num = 50000 + samples = gamma.sample(num, seed=137) + pdfs = gamma.prob(samples) + sample_vals, pdf_vals = self.evaluate([samples, pdfs]) + self.assertEqual(samples.get_shape(), (num, 2, 2)) + self.assertEqual(pdfs.get_shape(), (num, 2, 2)) + self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) + self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02) + self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02) + self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02) + if not stats: + return + self.assertAllClose( + stats.gamma.mean([[7., 11.], [7., 11.]], + scale=1 / np.array([[5., 5.], [6., 6.]])), + sample_vals.mean(axis=0), + atol=.1) + self.assertAllClose( + stats.gamma.var([[7., 11.], [7., 11.]], + scale=1 / np.array([[5., 5.], [6., 6.]])), + sample_vals.var(axis=0), + 
atol=.1) def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3): s_p = zip(sample_vals, pdf_vals) @@ -356,32 +337,29 @@ class GammaTest(test.TestCase): self.assertNear(1., total, err=err) def testGammaNonPositiveInitializationParamsRaises(self): - with self.test_session(): - alpha_v = constant_op.constant(0.0, name="alpha") - beta_v = constant_op.constant(1.0, name="beta") - with self.assertRaisesOpError("x > 0"): - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - validate_args=True) - self.evaluate(gamma.mean()) - alpha_v = constant_op.constant(1.0, name="alpha") - beta_v = constant_op.constant(0.0, name="beta") - with self.assertRaisesOpError("x > 0"): - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - validate_args=True) - self.evaluate(gamma.mean()) + alpha_v = constant_op.constant(0.0, name="alpha") + beta_v = constant_op.constant(1.0, name="beta") + with self.assertRaisesOpError("x > 0"): + gamma = gamma_lib.Gamma( + concentration=alpha_v, rate=beta_v, validate_args=True) + self.evaluate(gamma.mean()) + alpha_v = constant_op.constant(1.0, name="alpha") + beta_v = constant_op.constant(0.0, name="beta") + with self.assertRaisesOpError("x > 0"): + gamma = gamma_lib.Gamma( + concentration=alpha_v, rate=beta_v, validate_args=True) + self.evaluate(gamma.mean()) def testGammaWithSoftplusConcentrationRate(self): - with self.test_session(): - alpha_v = constant_op.constant([0.0, -2.1], name="alpha") - beta_v = constant_op.constant([1.0, -3.6], name="beta") - gamma = gamma_lib.GammaWithSoftplusConcentrationRate( - concentration=alpha_v, rate=beta_v) - self.assertAllEqual(self.evaluate(nn_ops.softplus(alpha_v)), - self.evaluate(gamma.concentration)) - self.assertAllEqual(self.evaluate(nn_ops.softplus(beta_v)), - self.evaluate(gamma.rate)) + alpha_v = constant_op.constant([0.0, -2.1], name="alpha") + beta_v = constant_op.constant([1.0, -3.6], name="beta") + gamma = gamma_lib.GammaWithSoftplusConcentrationRate( + concentration=alpha_v, 
rate=beta_v) + self.assertAllEqual( + self.evaluate(nn_ops.softplus(alpha_v)), + self.evaluate(gamma.concentration)) + self.assertAllEqual( + self.evaluate(nn_ops.softplus(beta_v)), self.evaluate(gamma.rate)) def testGammaGammaKL(self): alpha0 = np.array([3.]) @@ -391,15 +369,14 @@ class GammaTest(test.TestCase): beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.]) # Build graph. - with self.test_session(): - g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0) - g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1) - x = g0.sample(int(1e4), seed=0) - kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0) - kl_actual = kullback_leibler.kl_divergence(g0, g1) - - # Execute graph. - [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual]) + g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0) + g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1) + x = g0.sample(int(1e4), seed=0) + kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0) + kl_actual = kullback_leibler.kl_divergence(g0, g1) + + # Execute graph. 
+ [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual]) self.assertEqual(beta0.shape, kl_actual.get_shape()) diff --git a/tensorflow/python/kernel_tests/distributions/laplace_test.py b/tensorflow/python/kernel_tests/distributions/laplace_test.py index 24b243f647..630c2cb424 100644 --- a/tensorflow/python/kernel_tests/distributions/laplace_test.py +++ b/tensorflow/python/kernel_tests/distributions/laplace_test.py @@ -21,7 +21,6 @@ import importlib import numpy as np -from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape @@ -49,212 +48,198 @@ stats = try_import("scipy.stats") class LaplaceTest(test.TestCase): def testLaplaceShape(self): - with self.test_session(): - loc = constant_op.constant([3.0] * 5) - scale = constant_op.constant(11.0) - laplace = laplace_lib.Laplace(loc=loc, scale=scale) + loc = constant_op.constant([3.0] * 5) + scale = constant_op.constant(11.0) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) - self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,)) - self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), []) - self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([])) + self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,)) + self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), []) + self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([])) def testLaplaceLogPDF(self): - with self.test_session(): - batch_size = 6 - loc = constant_op.constant([2.0] * batch_size) - scale = constant_op.constant([3.0] * batch_size) - loc_v = 2.0 - scale_v = 3.0 - x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - laplace = laplace_lib.Laplace(loc=loc, scale=scale) - log_pdf = 
laplace.log_prob(x) - self.assertEqual(log_pdf.get_shape(), (6,)) - if not stats: - return - expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) + batch_size = 6 + loc = constant_op.constant([2.0] * batch_size) + scale = constant_op.constant([3.0] * batch_size) + loc_v = 2.0 + scale_v = 3.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) + log_pdf = laplace.log_prob(x) + self.assertEqual(log_pdf.get_shape(), (6,)) + if not stats: + return + expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) - pdf = laplace.prob(x) - self.assertEqual(pdf.get_shape(), (6,)) - self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) + pdf = laplace.prob(x) + self.assertEqual(pdf.get_shape(), (6,)) + self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) def testLaplaceLogPDFMultidimensional(self): - with self.test_session(): - batch_size = 6 - loc = constant_op.constant([[2.0, 4.0]] * batch_size) - scale = constant_op.constant([[3.0, 4.0]] * batch_size) - loc_v = np.array([2.0, 4.0]) - scale_v = np.array([3.0, 4.0]) - x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T - laplace = laplace_lib.Laplace(loc=loc, scale=scale) - log_pdf = laplace.log_prob(x) - log_pdf_values = self.evaluate(log_pdf) - self.assertEqual(log_pdf.get_shape(), (6, 2)) - - pdf = laplace.prob(x) - pdf_values = self.evaluate(pdf) - self.assertEqual(pdf.get_shape(), (6, 2)) - if not stats: - return - expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) - self.assertAllClose(log_pdf_values, expected_log_pdf) - self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) + batch_size = 6 + loc = constant_op.constant([[2.0, 4.0]] * batch_size) + scale = constant_op.constant([[3.0, 4.0]] * batch_size) + loc_v = np.array([2.0, 4.0]) + scale_v = 
np.array([3.0, 4.0]) + x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T + laplace = laplace_lib.Laplace(loc=loc, scale=scale) + log_pdf = laplace.log_prob(x) + log_pdf_values = self.evaluate(log_pdf) + self.assertEqual(log_pdf.get_shape(), (6, 2)) + + pdf = laplace.prob(x) + pdf_values = self.evaluate(pdf) + self.assertEqual(pdf.get_shape(), (6, 2)) + if not stats: + return + expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) + self.assertAllClose(log_pdf_values, expected_log_pdf) + self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) def testLaplaceLogPDFMultidimensionalBroadcasting(self): - with self.test_session(): - batch_size = 6 - loc = constant_op.constant([[2.0, 4.0]] * batch_size) - scale = constant_op.constant(3.0) - loc_v = np.array([2.0, 4.0]) - scale_v = 3.0 - x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T - laplace = laplace_lib.Laplace(loc=loc, scale=scale) - log_pdf = laplace.log_prob(x) - log_pdf_values = self.evaluate(log_pdf) - self.assertEqual(log_pdf.get_shape(), (6, 2)) - - pdf = laplace.prob(x) - pdf_values = self.evaluate(pdf) - self.assertEqual(pdf.get_shape(), (6, 2)) - if not stats: - return - expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) - self.assertAllClose(log_pdf_values, expected_log_pdf) - self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) + batch_size = 6 + loc = constant_op.constant([[2.0, 4.0]] * batch_size) + scale = constant_op.constant(3.0) + loc_v = np.array([2.0, 4.0]) + scale_v = 3.0 + x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T + laplace = laplace_lib.Laplace(loc=loc, scale=scale) + log_pdf = laplace.log_prob(x) + log_pdf_values = self.evaluate(log_pdf) + self.assertEqual(log_pdf.get_shape(), (6, 2)) + + pdf = laplace.prob(x) + pdf_values = self.evaluate(pdf) + self.assertEqual(pdf.get_shape(), (6, 2)) + if not stats: + return + expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) + 
self.assertAllClose(log_pdf_values, expected_log_pdf) + self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) def testLaplaceCDF(self): - with self.test_session(): - batch_size = 6 - loc = constant_op.constant([2.0] * batch_size) - scale = constant_op.constant([3.0] * batch_size) - loc_v = 2.0 - scale_v = 3.0 - x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + batch_size = 6 + loc = constant_op.constant([2.0] * batch_size) + scale = constant_op.constant([3.0] * batch_size) + loc_v = 2.0 + scale_v = 3.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - laplace = laplace_lib.Laplace(loc=loc, scale=scale) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) - cdf = laplace.cdf(x) - self.assertEqual(cdf.get_shape(), (6,)) - if not stats: - return - expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(cdf), expected_cdf) + cdf = laplace.cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) + if not stats: + return + expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testLaplaceLogCDF(self): - with self.test_session(): - batch_size = 6 - loc = constant_op.constant([2.0] * batch_size) - scale = constant_op.constant([3.0] * batch_size) - loc_v = 2.0 - scale_v = 3.0 - x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32) + batch_size = 6 + loc = constant_op.constant([2.0] * batch_size) + scale = constant_op.constant([3.0] * batch_size) + loc_v = 2.0 + scale_v = 3.0 + x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32) - laplace = laplace_lib.Laplace(loc=loc, scale=scale) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) - cdf = laplace.log_cdf(x) - self.assertEqual(cdf.get_shape(), (6,)) - if not stats: - return - expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(cdf), expected_cdf) + cdf = laplace.log_cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) + if not 
stats: + return + expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testLaplaceLogSurvivalFunction(self): - with self.test_session(): - batch_size = 6 - loc = constant_op.constant([2.0] * batch_size) - scale = constant_op.constant([3.0] * batch_size) - loc_v = 2.0 - scale_v = 3.0 - x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32) + batch_size = 6 + loc = constant_op.constant([2.0] * batch_size) + scale = constant_op.constant([3.0] * batch_size) + loc_v = 2.0 + scale_v = 3.0 + x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32) - laplace = laplace_lib.Laplace(loc=loc, scale=scale) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) - sf = laplace.log_survival_function(x) - self.assertEqual(sf.get_shape(), (6,)) - if not stats: - return - expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(sf), expected_sf) + sf = laplace.log_survival_function(x) + self.assertEqual(sf.get_shape(), (6,)) + if not stats: + return + expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(sf), expected_sf) def testLaplaceMean(self): - with self.test_session(): - loc_v = np.array([1.0, 3.0, 2.5]) - scale_v = np.array([1.0, 4.0, 5.0]) - laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) - self.assertEqual(laplace.mean().get_shape(), (3,)) - if not stats: - return - expected_means = stats.laplace.mean(loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(laplace.mean()), expected_means) + loc_v = np.array([1.0, 3.0, 2.5]) + scale_v = np.array([1.0, 4.0, 5.0]) + laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) + self.assertEqual(laplace.mean().get_shape(), (3,)) + if not stats: + return + expected_means = stats.laplace.mean(loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(laplace.mean()), expected_means) def testLaplaceMode(self): - with self.test_session(): - loc_v = np.array([0.5, 
3.0, 2.5]) - scale_v = np.array([1.0, 4.0, 5.0]) - laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) - self.assertEqual(laplace.mode().get_shape(), (3,)) - self.assertAllClose(self.evaluate(laplace.mode()), loc_v) + loc_v = np.array([0.5, 3.0, 2.5]) + scale_v = np.array([1.0, 4.0, 5.0]) + laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) + self.assertEqual(laplace.mode().get_shape(), (3,)) + self.assertAllClose(self.evaluate(laplace.mode()), loc_v) def testLaplaceVariance(self): - with self.test_session(): - loc_v = np.array([1.0, 3.0, 2.5]) - scale_v = np.array([1.0, 4.0, 5.0]) - laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) - self.assertEqual(laplace.variance().get_shape(), (3,)) - if not stats: - return - expected_variances = stats.laplace.var(loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(laplace.variance()), expected_variances) + loc_v = np.array([1.0, 3.0, 2.5]) + scale_v = np.array([1.0, 4.0, 5.0]) + laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) + self.assertEqual(laplace.variance().get_shape(), (3,)) + if not stats: + return + expected_variances = stats.laplace.var(loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(laplace.variance()), expected_variances) def testLaplaceStd(self): - with self.test_session(): - loc_v = np.array([1.0, 3.0, 2.5]) - scale_v = np.array([1.0, 4.0, 5.0]) - laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) - self.assertEqual(laplace.stddev().get_shape(), (3,)) - if not stats: - return - expected_stddev = stats.laplace.std(loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev) + loc_v = np.array([1.0, 3.0, 2.5]) + scale_v = np.array([1.0, 4.0, 5.0]) + laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) + self.assertEqual(laplace.stddev().get_shape(), (3,)) + if not stats: + return + expected_stddev = stats.laplace.std(loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev) def 
testLaplaceEntropy(self): - with self.test_session(): - loc_v = np.array([1.0, 3.0, 2.5]) - scale_v = np.array([1.0, 4.0, 5.0]) - laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) - self.assertEqual(laplace.entropy().get_shape(), (3,)) - if not stats: - return - expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy) + loc_v = np.array([1.0, 3.0, 2.5]) + scale_v = np.array([1.0, 4.0, 5.0]) + laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) + self.assertEqual(laplace.entropy().get_shape(), (3,)) + if not stats: + return + expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy) def testLaplaceSample(self): - with session.Session(): - loc_v = 4.0 - scale_v = 3.0 - loc = constant_op.constant(loc_v) - scale = constant_op.constant(scale_v) - n = 100000 - laplace = laplace_lib.Laplace(loc=loc, scale=scale) - samples = laplace.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(samples.get_shape(), (n,)) - self.assertEqual(sample_values.shape, (n,)) - if not stats: - return - self.assertAllClose( - sample_values.mean(), - stats.laplace.mean( - loc_v, scale=scale_v), - rtol=0.05, - atol=0.) - self.assertAllClose( - sample_values.var(), - stats.laplace.var(loc_v, scale=scale_v), - rtol=0.05, - atol=0.) - self.assertTrue(self._kstest(loc_v, scale_v, sample_values)) + loc_v = 4.0 + scale_v = 3.0 + loc = constant_op.constant(loc_v) + scale = constant_op.constant(scale_v) + n = 100000 + laplace = laplace_lib.Laplace(loc=loc, scale=scale) + samples = laplace.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(samples.get_shape(), (n,)) + self.assertEqual(sample_values.shape, (n,)) + if not stats: + return + self.assertAllClose( + sample_values.mean(), + stats.laplace.mean(loc_v, scale=scale_v), + rtol=0.05, + atol=0.) 
+ self.assertAllClose( + sample_values.var(), + stats.laplace.var(loc_v, scale=scale_v), + rtol=0.05, + atol=0.) + self.assertTrue(self._kstest(loc_v, scale_v, sample_values)) def testLaplaceFullyReparameterized(self): loc = constant_op.constant(4.0) @@ -269,39 +254,37 @@ class LaplaceTest(test.TestCase): self.assertIsNotNone(grad_scale) def testLaplaceSampleMultiDimensional(self): - with session.Session(): - loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 - scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 - laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) - n = 10000 - samples = laplace.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(samples.get_shape(), (n, 10, 100)) - self.assertEqual(sample_values.shape, (n, 10, 100)) - zeros = np.zeros_like(loc_v + scale_v) # 10 x 100 - loc_bc = loc_v + zeros - scale_bc = scale_v + zeros - if not stats: - return - self.assertAllClose( - sample_values.mean(axis=0), - stats.laplace.mean( - loc_bc, scale=scale_bc), - rtol=0.35, - atol=0.) - self.assertAllClose( - sample_values.var(axis=0), - stats.laplace.var(loc_bc, scale=scale_bc), - rtol=0.10, - atol=0.) 
- fails = 0 - trials = 0 - for ai, a in enumerate(np.reshape(loc_v, [-1])): - for bi, b in enumerate(np.reshape(scale_v, [-1])): - s = sample_values[:, bi, ai] - trials += 1 - fails += 0 if self._kstest(a, b, s) else 1 - self.assertLess(fails, trials * 0.03) + loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 + scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 + laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) + n = 10000 + samples = laplace.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(samples.get_shape(), (n, 10, 100)) + self.assertEqual(sample_values.shape, (n, 10, 100)) + zeros = np.zeros_like(loc_v + scale_v) # 10 x 100 + loc_bc = loc_v + zeros + scale_bc = scale_v + zeros + if not stats: + return + self.assertAllClose( + sample_values.mean(axis=0), + stats.laplace.mean(loc_bc, scale=scale_bc), + rtol=0.35, + atol=0.) + self.assertAllClose( + sample_values.var(axis=0), + stats.laplace.var(loc_bc, scale=scale_bc), + rtol=0.10, + atol=0.) + fails = 0 + trials = 0 + for ai, a in enumerate(np.reshape(loc_v, [-1])): + for bi, b in enumerate(np.reshape(scale_v, [-1])): + s = sample_values[:, bi, ai] + trials += 1 + fails += 0 if self._kstest(a, b, s) else 1 + self.assertLess(fails, trials * 0.03) def _kstest(self, loc, scale, samples): # Uses the Kolmogorov-Smirnov test for goodness of fit. 
@@ -349,30 +332,26 @@ class LaplaceTest(test.TestCase): self.assertNear(1., total, err=err) def testLaplaceNonPositiveInitializationParamsRaises(self): - with self.test_session(): - loc_v = constant_op.constant(0.0, name="loc") - scale_v = constant_op.constant(-1.0, name="scale") - with self.assertRaisesOpError( - "Condition x > 0 did not hold element-wise"): - laplace = laplace_lib.Laplace( - loc=loc_v, scale=scale_v, validate_args=True) - self.evaluate(laplace.mean()) - loc_v = constant_op.constant(1.0, name="loc") - scale_v = constant_op.constant(0.0, name="scale") - with self.assertRaisesOpError( - "Condition x > 0 did not hold element-wise"): - laplace = laplace_lib.Laplace( - loc=loc_v, scale=scale_v, validate_args=True) - self.evaluate(laplace.mean()) + loc_v = constant_op.constant(0.0, name="loc") + scale_v = constant_op.constant(-1.0, name="scale") + with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"): + laplace = laplace_lib.Laplace( + loc=loc_v, scale=scale_v, validate_args=True) + self.evaluate(laplace.mean()) + loc_v = constant_op.constant(1.0, name="loc") + scale_v = constant_op.constant(0.0, name="scale") + with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"): + laplace = laplace_lib.Laplace( + loc=loc_v, scale=scale_v, validate_args=True) + self.evaluate(laplace.mean()) def testLaplaceWithSoftplusScale(self): - with self.test_session(): - loc_v = constant_op.constant([0.0, 1.0], name="loc") - scale_v = constant_op.constant([-1.0, 2.0], name="scale") - laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v) - self.assertAllClose( - self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale)) - self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc)) + loc_v = constant_op.constant([0.0, 1.0], name="loc") + scale_v = constant_op.constant([-1.0, 2.0], name="scale") + laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v) + self.assertAllClose( + 
self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale)) + self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc)) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py index 7ff48c0c10..de73a40b23 100644 --- a/tensorflow/python/kernel_tests/distributions/normal_test.py +++ b/tensorflow/python/kernel_tests/distributions/normal_test.py @@ -61,16 +61,15 @@ class NormalTest(test.TestCase): self.assertAllEqual(all_true, is_finite) def _testParamShapes(self, sample_shape, expected): - with self.test_session(): - param_shapes = normal_lib.Normal.param_shapes(sample_shape) - mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"] - self.assertAllEqual(expected, self.evaluate(mu_shape)) - self.assertAllEqual(expected, self.evaluate(sigma_shape)) - mu = array_ops.zeros(mu_shape) - sigma = array_ops.ones(sigma_shape) - self.assertAllEqual( - expected, - self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample()))) + param_shapes = normal_lib.Normal.param_shapes(sample_shape) + mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"] + self.assertAllEqual(expected, self.evaluate(mu_shape)) + self.assertAllEqual(expected, self.evaluate(sigma_shape)) + mu = array_ops.zeros(mu_shape) + sigma = array_ops.ones(sigma_shape) + self.assertAllEqual( + expected, + self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample()))) def _testParamStaticShapes(self, sample_shape, expected): param_shapes = normal_lib.Normal.param_static_shapes(sample_shape) @@ -91,156 +90,150 @@ class NormalTest(test.TestCase): self._testParamStaticShapes( tensor_shape.TensorShape(sample_shape), sample_shape) - @test_util.run_in_graph_and_eager_modes + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNormalWithSoftplusScale(self): - with self.test_session(): - mu = array_ops.zeros((10, 3)) - rho = 
array_ops.ones((10, 3)) * -2. - normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho) - self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc)) - self.assertAllEqual( - self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale)) + mu = array_ops.zeros((10, 3)) + rho = array_ops.ones((10, 3)) * -2. + normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho) + self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc)) + self.assertAllEqual( + self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale)) @test_util.run_in_graph_and_eager_modes def testNormalLogPDF(self): - with self.test_session(): - batch_size = 6 - mu = constant_op.constant([3.0] * batch_size) - sigma = constant_op.constant([math.sqrt(10.0)] * batch_size) - x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) - normal = normal_lib.Normal(loc=mu, scale=sigma) - - log_pdf = normal.log_prob(x) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(log_pdf).shape) - self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape) + batch_size = 6 + mu = constant_op.constant([3.0] * batch_size) + sigma = constant_op.constant([math.sqrt(10.0)] * batch_size) + x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) + normal = normal_lib.Normal(loc=mu, scale=sigma) - pdf = normal.prob(x) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), pdf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(pdf).shape) - self.assertAllEqual(normal.batch_shape, pdf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape) - - if not stats: - return - expected_log_pdf = stats.norm(self.evaluate(mu), - self.evaluate(sigma)).logpdf(x) - self.assertAllClose(expected_log_pdf, 
self.evaluate(log_pdf)) - self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf)) + log_pdf = normal.log_prob(x) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(log_pdf).shape) + self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape) + + pdf = normal.prob(x) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), pdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(pdf).shape) + self.assertAllEqual(normal.batch_shape, pdf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape) + + if not stats: + return + expected_log_pdf = stats.norm(self.evaluate(mu), + self.evaluate(sigma)).logpdf(x) + self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf)) + self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf)) @test_util.run_in_graph_and_eager_modes def testNormalLogPDFMultidimensional(self): - with self.test_session(): - batch_size = 6 - mu = constant_op.constant([[3.0, -3.0]] * batch_size) - sigma = constant_op.constant([[math.sqrt(10.0), math.sqrt(15.0)]] * - batch_size) - x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T - normal = normal_lib.Normal(loc=mu, scale=sigma) - - log_pdf = normal.log_prob(x) - log_pdf_values = self.evaluate(log_pdf) - self.assertEqual(log_pdf.get_shape(), (6, 2)) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(log_pdf).shape) - self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape) - - pdf = normal.prob(x) - pdf_values = self.evaluate(pdf) - self.assertEqual(pdf.get_shape(), (6, 2)) - self.assertAllEqual( - 
self.evaluate(normal.batch_shape_tensor()), pdf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), pdf_values.shape) - self.assertAllEqual(normal.batch_shape, pdf.get_shape()) - self.assertAllEqual(normal.batch_shape, pdf_values.shape) + batch_size = 6 + mu = constant_op.constant([[3.0, -3.0]] * batch_size) + sigma = constant_op.constant( + [[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size) + x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T + normal = normal_lib.Normal(loc=mu, scale=sigma) - if not stats: - return - expected_log_pdf = stats.norm(self.evaluate(mu), - self.evaluate(sigma)).logpdf(x) - self.assertAllClose(expected_log_pdf, log_pdf_values) - self.assertAllClose(np.exp(expected_log_pdf), pdf_values) + log_pdf = normal.log_prob(x) + log_pdf_values = self.evaluate(log_pdf) + self.assertEqual(log_pdf.get_shape(), (6, 2)) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(log_pdf).shape) + self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape) + + pdf = normal.prob(x) + pdf_values = self.evaluate(pdf) + self.assertEqual(pdf.get_shape(), (6, 2)) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), pdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), pdf_values.shape) + self.assertAllEqual(normal.batch_shape, pdf.get_shape()) + self.assertAllEqual(normal.batch_shape, pdf_values.shape) + + if not stats: + return + expected_log_pdf = stats.norm(self.evaluate(mu), + self.evaluate(sigma)).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf_values) + self.assertAllClose(np.exp(expected_log_pdf), pdf_values) @test_util.run_in_graph_and_eager_modes def testNormalCDF(self): - with self.test_session(): - batch_size = 50 - mu = self._rng.randn(batch_size) - sigma = 
self._rng.rand(batch_size) + 1.0 - x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + batch_size = 50 + mu = self._rng.randn(batch_size) + sigma = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) - normal = normal_lib.Normal(loc=mu, scale=sigma) - cdf = normal.cdf(x) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), cdf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(cdf).shape) - self.assertAllEqual(normal.batch_shape, cdf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) - if not stats: - return - expected_cdf = stats.norm(mu, sigma).cdf(x) - self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0) + normal = normal_lib.Normal(loc=mu, scale=sigma) + cdf = normal.cdf(x) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), cdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(cdf).shape) + self.assertAllEqual(normal.batch_shape, cdf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) + if not stats: + return + expected_cdf = stats.norm(mu, sigma).cdf(x) + self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0) @test_util.run_in_graph_and_eager_modes def testNormalSurvivalFunction(self): - with self.test_session(): - batch_size = 50 - mu = self._rng.randn(batch_size) - sigma = self._rng.rand(batch_size) + 1.0 - x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + batch_size = 50 + mu = self._rng.randn(batch_size) + sigma = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) - normal = normal_lib.Normal(loc=mu, scale=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) - sf = normal.survival_function(x) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), sf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - 
self.evaluate(sf).shape) - self.assertAllEqual(normal.batch_shape, sf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) - if not stats: - return - expected_sf = stats.norm(mu, sigma).sf(x) - self.assertAllClose(expected_sf, self.evaluate(sf), atol=0) + sf = normal.survival_function(x) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), sf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(sf).shape) + self.assertAllEqual(normal.batch_shape, sf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) + if not stats: + return + expected_sf = stats.norm(mu, sigma).sf(x) + self.assertAllClose(expected_sf, self.evaluate(sf), atol=0) @test_util.run_in_graph_and_eager_modes def testNormalLogCDF(self): - with self.test_session(): - batch_size = 50 - mu = self._rng.randn(batch_size) - sigma = self._rng.rand(batch_size) + 1.0 - x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64) + batch_size = 50 + mu = self._rng.randn(batch_size) + sigma = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64) - normal = normal_lib.Normal(loc=mu, scale=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) - cdf = normal.log_cdf(x) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), cdf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(cdf).shape) - self.assertAllEqual(normal.batch_shape, cdf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) + cdf = normal.log_cdf(x) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), cdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(cdf).shape) + self.assertAllEqual(normal.batch_shape, cdf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) - if not stats: - return - expected_cdf = stats.norm(mu, 
sigma).logcdf(x) - self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3) + if not stats: + return + expected_cdf = stats.norm(mu, sigma).logcdf(x) + self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3) def testFiniteGradientAtDifficultPoints(self): for dtype in [np.float32, np.float64]: @@ -256,7 +249,7 @@ class NormalTest(test.TestCase): ]: value = func(x) grads = gradients_impl.gradients(value, [mu, sigma]) - with self.test_session(graph=g): + with self.session(graph=g): variables.global_variables_initializer().run() self.assertAllFinite(value) self.assertAllFinite(grads[0]) @@ -264,112 +257,106 @@ class NormalTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNormalLogSurvivalFunction(self): - with self.test_session(): - batch_size = 50 - mu = self._rng.randn(batch_size) - sigma = self._rng.rand(batch_size) + 1.0 - x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64) + batch_size = 50 + mu = self._rng.randn(batch_size) + sigma = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64) - normal = normal_lib.Normal(loc=mu, scale=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) - sf = normal.log_survival_function(x) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), sf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(sf).shape) - self.assertAllEqual(normal.batch_shape, sf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) + sf = normal.log_survival_function(x) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), sf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(sf).shape) + self.assertAllEqual(normal.batch_shape, sf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) - if not stats: - return - expected_sf = stats.norm(mu, sigma).logsf(x) - 
self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5) + if not stats: + return + expected_sf = stats.norm(mu, sigma).logsf(x) + self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5) @test_util.run_in_graph_and_eager_modes def testNormalEntropyWithScalarInputs(self): # Scipy.stats.norm cannot deal with the shapes in the other test. - with self.test_session(): - mu_v = 2.34 - sigma_v = 4.56 - normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) - - entropy = normal.entropy() - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), entropy.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(entropy).shape) - self.assertAllEqual(normal.batch_shape, entropy.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) - # scipy.stats.norm cannot deal with these shapes. - if not stats: - return - expected_entropy = stats.norm(mu_v, sigma_v).entropy() - self.assertAllClose(expected_entropy, self.evaluate(entropy)) + mu_v = 2.34 + sigma_v = 4.56 + normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) + + entropy = normal.entropy() + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), entropy.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(entropy).shape) + self.assertAllEqual(normal.batch_shape, entropy.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) + # scipy.stats.norm cannot deal with these shapes. + if not stats: + return + expected_entropy = stats.norm(mu_v, sigma_v).entropy() + self.assertAllClose(expected_entropy, self.evaluate(entropy)) @test_util.run_in_graph_and_eager_modes def testNormalEntropy(self): - with self.test_session(): - mu_v = np.array([1.0, 1.0, 1.0]) - sigma_v = np.array([[1.0, 2.0, 3.0]]).T - normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) - - # scipy.stats.norm cannot deal with these shapes. 
- sigma_broadcast = mu_v * sigma_v - expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast** - 2) - entropy = normal.entropy() - np.testing.assert_allclose(expected_entropy, self.evaluate(entropy)) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), entropy.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(entropy).shape) - self.assertAllEqual(normal.batch_shape, entropy.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) - - @test_util.run_in_graph_and_eager_modes + mu_v = np.array([1.0, 1.0, 1.0]) + sigma_v = np.array([[1.0, 2.0, 3.0]]).T + normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) + + # scipy.stats.norm cannot deal with these shapes. + sigma_broadcast = mu_v * sigma_v + expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**2) + entropy = normal.entropy() + np.testing.assert_allclose(expected_entropy, self.evaluate(entropy)) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), entropy.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(entropy).shape) + self.assertAllEqual(normal.batch_shape, entropy.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNormalMeanAndMode(self): - with self.test_session(): - # Mu will be broadcast to [7, 7, 7]. - mu = [7.] - sigma = [11., 12., 13.] + # Mu will be broadcast to [7, 7, 7]. + mu = [7.] + sigma = [11., 12., 13.] 
- normal = normal_lib.Normal(loc=mu, scale=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) - self.assertAllEqual((3,), normal.mean().get_shape()) - self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean())) + self.assertAllEqual((3,), normal.mean().get_shape()) + self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean())) - self.assertAllEqual((3,), normal.mode().get_shape()) - self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode())) + self.assertAllEqual((3,), normal.mode().get_shape()) + self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode())) @test_util.run_in_graph_and_eager_modes def testNormalQuantile(self): - with self.test_session(): - batch_size = 52 - mu = self._rng.randn(batch_size) - sigma = self._rng.rand(batch_size) + 1.0 - p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64) - # Quantile performs piecewise rational approximation so adding some - # special input values to make sure we hit all the pieces. - p = np.hstack((p, np.exp(-33), 1. - np.exp(-33))) + batch_size = 52 + mu = self._rng.randn(batch_size) + sigma = self._rng.rand(batch_size) + 1.0 + p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64) + # Quantile performs piecewise rational approximation so adding some + # special input values to make sure we hit all the pieces. + p = np.hstack((p, np.exp(-33), 1. 
- np.exp(-33))) - normal = normal_lib.Normal(loc=mu, scale=sigma) - x = normal.quantile(p) + normal = normal_lib.Normal(loc=mu, scale=sigma) + x = normal.quantile(p) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), x.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(x).shape) - self.assertAllEqual(normal.batch_shape, x.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), x.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(x).shape) + self.assertAllEqual(normal.batch_shape, x.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape) - if not stats: - return - expected_x = stats.norm(mu, sigma).ppf(p) - self.assertAllClose(expected_x, self.evaluate(x), atol=0.) + if not stats: + return + expected_x = stats.norm(mu, sigma).ppf(p) + self.assertAllClose(expected_x, self.evaluate(x), atol=0.) def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype): g = ops.Graph() @@ -385,7 +372,7 @@ class NormalTest(test.TestCase): value = dist.quantile(p) grads = gradients_impl.gradients(value, [mu, p]) - with self.test_session(graph=g): + with self.cached_session(graph=g): variables.global_variables_initializer().run() self.assertAllFinite(grads[0]) self.assertAllFinite(grads[1]) @@ -398,61 +385,58 @@ class NormalTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNormalVariance(self): - with self.test_session(): - # sigma will be broadcast to [7, 7, 7] - mu = [1., 2., 3.] - sigma = [7.] + # sigma will be broadcast to [7, 7, 7] + mu = [1., 2., 3.] + sigma = [7.] 
- normal = normal_lib.Normal(loc=mu, scale=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) - self.assertAllEqual((3,), normal.variance().get_shape()) - self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance())) + self.assertAllEqual((3,), normal.variance().get_shape()) + self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance())) @test_util.run_in_graph_and_eager_modes def testNormalStandardDeviation(self): - with self.test_session(): - # sigma will be broadcast to [7, 7, 7] - mu = [1., 2., 3.] - sigma = [7.] + # sigma will be broadcast to [7, 7, 7] + mu = [1., 2., 3.] + sigma = [7.] - normal = normal_lib.Normal(loc=mu, scale=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) - self.assertAllEqual((3,), normal.stddev().get_shape()) - self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev())) + self.assertAllEqual((3,), normal.stddev().get_shape()) + self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev())) @test_util.run_in_graph_and_eager_modes def testNormalSample(self): - with self.test_session(): - mu = constant_op.constant(3.0) - sigma = constant_op.constant(math.sqrt(3.0)) - mu_v = 3.0 - sigma_v = np.sqrt(3.0) - n = constant_op.constant(100000) - normal = normal_lib.Normal(loc=mu, scale=sigma) - samples = normal.sample(n) - sample_values = self.evaluate(samples) - # Note that the standard error for the sample mean is ~ sigma / sqrt(n). - # The sample variance similarly is dependent on sigma and n. - # Thus, the tolerances below are very sensitive to number of samples - # as well as the variances chosen. 
- self.assertEqual(sample_values.shape, (100000,)) - self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1) - self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1) - - expected_samples_shape = tensor_shape.TensorShape( - [self.evaluate(n)]).concatenate( - tensor_shape.TensorShape( - self.evaluate(normal.batch_shape_tensor()))) - - self.assertAllEqual(expected_samples_shape, samples.get_shape()) - self.assertAllEqual(expected_samples_shape, sample_values.shape) - - expected_samples_shape = ( - tensor_shape.TensorShape([self.evaluate(n)]).concatenate( - normal.batch_shape)) - - self.assertAllEqual(expected_samples_shape, samples.get_shape()) - self.assertAllEqual(expected_samples_shape, sample_values.shape) + mu = constant_op.constant(3.0) + sigma = constant_op.constant(math.sqrt(3.0)) + mu_v = 3.0 + sigma_v = np.sqrt(3.0) + n = constant_op.constant(100000) + normal = normal_lib.Normal(loc=mu, scale=sigma) + samples = normal.sample(n) + sample_values = self.evaluate(samples) + # Note that the standard error for the sample mean is ~ sigma / sqrt(n). + # The sample variance similarly is dependent on sigma and n. + # Thus, the tolerances below are very sensitive to number of samples + # as well as the variances chosen. 
+ self.assertEqual(sample_values.shape, (100000,)) + self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1) + self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1) + + expected_samples_shape = tensor_shape.TensorShape( + [self.evaluate(n)]).concatenate( + tensor_shape.TensorShape( + self.evaluate(normal.batch_shape_tensor()))) + + self.assertAllEqual(expected_samples_shape, samples.get_shape()) + self.assertAllEqual(expected_samples_shape, sample_values.shape) + + expected_samples_shape = ( + tensor_shape.TensorShape([self.evaluate(n)]).concatenate( + normal.batch_shape)) + + self.assertAllEqual(expected_samples_shape, samples.get_shape()) + self.assertAllEqual(expected_samples_shape, sample_values.shape) def testNormalFullyReparameterized(self): mu = constant_op.constant(4.0) @@ -468,66 +452,63 @@ class NormalTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNormalSampleMultiDimensional(self): - with self.test_session(): - batch_size = 2 - mu = constant_op.constant([[3.0, -3.0]] * batch_size) - sigma = constant_op.constant([[math.sqrt(2.0), math.sqrt(3.0)]] * - batch_size) - mu_v = [3.0, -3.0] - sigma_v = [np.sqrt(2.0), np.sqrt(3.0)] - n = constant_op.constant(100000) - normal = normal_lib.Normal(loc=mu, scale=sigma) - samples = normal.sample(n) - sample_values = self.evaluate(samples) - # Note that the standard error for the sample mean is ~ sigma / sqrt(n). - # The sample variance similarly is dependent on sigma and n. - # Thus, the tolerances below are very sensitive to number of samples - # as well as the variances chosen. 
- self.assertEqual(samples.get_shape(), (100000, batch_size, 2)) - self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1) - self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1) - self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1) - self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1) - - expected_samples_shape = tensor_shape.TensorShape( - [self.evaluate(n)]).concatenate( - tensor_shape.TensorShape( - self.evaluate(normal.batch_shape_tensor()))) - self.assertAllEqual(expected_samples_shape, samples.get_shape()) - self.assertAllEqual(expected_samples_shape, sample_values.shape) - - expected_samples_shape = ( - tensor_shape.TensorShape([self.evaluate(n)]).concatenate( - normal.batch_shape)) - self.assertAllEqual(expected_samples_shape, samples.get_shape()) - self.assertAllEqual(expected_samples_shape, sample_values.shape) + batch_size = 2 + mu = constant_op.constant([[3.0, -3.0]] * batch_size) + sigma = constant_op.constant( + [[math.sqrt(2.0), math.sqrt(3.0)]] * batch_size) + mu_v = [3.0, -3.0] + sigma_v = [np.sqrt(2.0), np.sqrt(3.0)] + n = constant_op.constant(100000) + normal = normal_lib.Normal(loc=mu, scale=sigma) + samples = normal.sample(n) + sample_values = self.evaluate(samples) + # Note that the standard error for the sample mean is ~ sigma / sqrt(n). + # The sample variance similarly is dependent on sigma and n. + # Thus, the tolerances below are very sensitive to number of samples + # as well as the variances chosen. 
+ self.assertEqual(samples.get_shape(), (100000, batch_size, 2)) + self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1) + self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1) + self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1) + self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1) + + expected_samples_shape = tensor_shape.TensorShape( + [self.evaluate(n)]).concatenate( + tensor_shape.TensorShape( + self.evaluate(normal.batch_shape_tensor()))) + self.assertAllEqual(expected_samples_shape, samples.get_shape()) + self.assertAllEqual(expected_samples_shape, sample_values.shape) + + expected_samples_shape = ( + tensor_shape.TensorShape([self.evaluate(n)]).concatenate( + normal.batch_shape)) + self.assertAllEqual(expected_samples_shape, samples.get_shape()) + self.assertAllEqual(expected_samples_shape, sample_values.shape) @test_util.run_in_graph_and_eager_modes def testNegativeSigmaFails(self): - with self.test_session(): - with self.assertRaisesOpError("Condition x > 0 did not hold"): - normal = normal_lib.Normal( - loc=[1.], scale=[-5.], validate_args=True, name="G") - self.evaluate(normal.mean()) + with self.assertRaisesOpError("Condition x > 0 did not hold"): + normal = normal_lib.Normal( + loc=[1.], scale=[-5.], validate_args=True, name="G") + self.evaluate(normal.mean()) @test_util.run_in_graph_and_eager_modes def testNormalShape(self): - with self.test_session(): - mu = constant_op.constant([-3.0] * 5) - sigma = constant_op.constant(11.0) - normal = normal_lib.Normal(loc=mu, scale=sigma) + mu = constant_op.constant([-3.0] * 5) + sigma = constant_op.constant(11.0) + normal = normal_lib.Normal(loc=mu, scale=sigma) - self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5]) - self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), []) - self.assertEqual(normal.event_shape, tensor_shape.TensorShape([])) + 
self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5]) + self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), []) + self.assertEqual(normal.event_shape, tensor_shape.TensorShape([])) def testNormalShapeWithPlaceholders(self): mu = array_ops.placeholder(dtype=dtypes.float32) sigma = array_ops.placeholder(dtype=dtypes.float32) normal = normal_lib.Normal(loc=mu, scale=sigma) - with self.test_session() as sess: + with self.cached_session() as sess: # get_batch_shape should return an "<unknown>" tensor. self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None)) self.assertEqual(normal.event_shape, ()) diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py index a634194ce5..cc43e12168 100644 --- a/tensorflow/python/kernel_tests/distributions/special_math_test.py +++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py @@ -92,22 +92,21 @@ class NdtriTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNdtri(self): """Verifies that ndtri computation is correct.""" - with self.test_session(): - if not special: - return + if not special: + return - p = np.linspace(0., 1.0, 50).astype(np.float64) - # Quantile performs piecewise rational approximation so adding some - # special input values to make sure we hit all the pieces. - p = np.hstack((p, np.exp(-32), 1. - np.exp(-32), - np.exp(-2), 1. - np.exp(-2))) - expected_x = special.ndtri(p) - x = special_math.ndtri(p) - self.assertAllClose(expected_x, self.evaluate(x), atol=0.) + p = np.linspace(0., 1.0, 50).astype(np.float64) + # Quantile performs piecewise rational approximation so adding some + # special input values to make sure we hit all the pieces. + p = np.hstack((p, np.exp(-32), 1. - np.exp(-32), np.exp(-2), + 1. 
- np.exp(-2))) + expected_x = special.ndtri(p) + x = special_math.ndtri(p) + self.assertAllClose(expected_x, self.evaluate(x), atol=0.) def testNdtriDynamicShape(self): """Verifies that ndtri computation is correct.""" - with self.test_session() as sess: + with self.cached_session() as sess: if not special: return @@ -286,7 +285,7 @@ class NdtrGradientTest(test.TestCase): def _test_grad_accuracy(self, dtype, grid_spec, error_spec): raw_grid = _make_grid(dtype, grid_spec) grid = ops.convert_to_tensor(raw_grid) - with self.test_session(): + with self.cached_session(): fn = sm.log_ndtr if self._use_log else sm.ndtr # If there are N points in the grid, @@ -355,7 +354,7 @@ class LogNdtrGradientTest(NdtrGradientTest): class ErfInvTest(test.TestCase): def testErfInvValues(self): - with self.test_session(): + with self.cached_session(): if not special: return @@ -366,7 +365,7 @@ class ErfInvTest(test.TestCase): self.assertAllClose(expected_x, x.eval(), atol=0.) def testErfInvIntegerInput(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(TypeError): x = np.array([1, 2, 3]).astype(np.int32) @@ -397,7 +396,7 @@ class LogCDFLaplaceTest(test.TestCase): self.assertAllEqual(np.ones_like(x, dtype=np.bool), x) def _test_grid_log(self, dtype, scipy_dtype, grid_spec, error_spec): - with self.test_session(): + with self.cached_session(): grid = _make_grid(dtype, grid_spec) actual = sm.log_cdf_laplace(grid).eval() @@ -439,7 +438,7 @@ class LogCDFLaplaceTest(test.TestCase): ErrorSpec(rtol=0.05, atol=0)) def test_float32_extreme_values_result_and_gradient_finite_and_nonzero(self): - with self.test_session() as sess: + with self.cached_session() as sess: # On the lower branch, log_cdf_laplace(x) = x, so we know this will be # fine, but test to -200 anyways. 
grid = _make_grid( @@ -458,7 +457,7 @@ class LogCDFLaplaceTest(test.TestCase): self.assertFalse(np.any(grad_ == 0)) def test_float64_extreme_values_result_and_gradient_finite_and_nonzero(self): - with self.test_session() as sess: + with self.cached_session() as sess: # On the lower branch, log_cdf_laplace(x) = x, so we know this will be # fine, but test to -200 anyways. grid = _make_grid( diff --git a/tensorflow/python/kernel_tests/distributions/student_t_test.py b/tensorflow/python/kernel_tests/distributions/student_t_test.py index 05590542ef..b34b538160 100644 --- a/tensorflow/python/kernel_tests/distributions/student_t_test.py +++ b/tensorflow/python/kernel_tests/distributions/student_t_test.py @@ -50,100 +50,96 @@ stats = try_import("scipy.stats") class StudentTTest(test.TestCase): def testStudentPDFAndLogPDF(self): - with self.test_session(): - batch_size = 6 - df = constant_op.constant([3.] * batch_size) - mu = constant_op.constant([7.] * batch_size) - sigma = constant_op.constant([8.] * batch_size) - df_v = 3. - mu_v = 7. - sigma_v = 8. - t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32) - student = student_t.StudentT(df, loc=mu, scale=-sigma) - - log_pdf = student.log_prob(t) - self.assertEquals(log_pdf.get_shape(), (6,)) - log_pdf_values = self.evaluate(log_pdf) - pdf = student.prob(t) - self.assertEquals(pdf.get_shape(), (6,)) - pdf_values = self.evaluate(pdf) - - if not stats: - return - - expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v) - expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v) - self.assertAllClose(expected_log_pdf, log_pdf_values) - self.assertAllClose(np.log(expected_pdf), log_pdf_values) - self.assertAllClose(expected_pdf, pdf_values) - self.assertAllClose(np.exp(expected_log_pdf), pdf_values) + batch_size = 6 + df = constant_op.constant([3.] * batch_size) + mu = constant_op.constant([7.] * batch_size) + sigma = constant_op.constant([8.] * batch_size) + df_v = 3. + mu_v = 7. + sigma_v = 8. 
+ t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32) + student = student_t.StudentT(df, loc=mu, scale=-sigma) + + log_pdf = student.log_prob(t) + self.assertEquals(log_pdf.get_shape(), (6,)) + log_pdf_values = self.evaluate(log_pdf) + pdf = student.prob(t) + self.assertEquals(pdf.get_shape(), (6,)) + pdf_values = self.evaluate(pdf) + + if not stats: + return + + expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v) + expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v) + self.assertAllClose(expected_log_pdf, log_pdf_values) + self.assertAllClose(np.log(expected_pdf), log_pdf_values) + self.assertAllClose(expected_pdf, pdf_values) + self.assertAllClose(np.exp(expected_log_pdf), pdf_values) def testStudentLogPDFMultidimensional(self): - with self.test_session(): - batch_size = 6 - df = constant_op.constant([[1.5, 7.2]] * batch_size) - mu = constant_op.constant([[3., -3.]] * batch_size) - sigma = constant_op.constant([[-math.sqrt(10.), math.sqrt(15.)]] * - batch_size) - df_v = np.array([1.5, 7.2]) - mu_v = np.array([3., -3.]) - sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)]) - t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T - student = student_t.StudentT(df, loc=mu, scale=sigma) - log_pdf = student.log_prob(t) - log_pdf_values = self.evaluate(log_pdf) - self.assertEqual(log_pdf.get_shape(), (6, 2)) - pdf = student.prob(t) - pdf_values = self.evaluate(pdf) - self.assertEqual(pdf.get_shape(), (6, 2)) - - if not stats: - return - expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v) - expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v) - self.assertAllClose(expected_log_pdf, log_pdf_values) - self.assertAllClose(np.log(expected_pdf), log_pdf_values) - self.assertAllClose(expected_pdf, pdf_values) - self.assertAllClose(np.exp(expected_log_pdf), pdf_values) + batch_size = 6 + df = constant_op.constant([[1.5, 7.2]] * batch_size) + mu = constant_op.constant([[3., -3.]] * batch_size) + sigma = 
constant_op.constant( + [[-math.sqrt(10.), math.sqrt(15.)]] * batch_size) + df_v = np.array([1.5, 7.2]) + mu_v = np.array([3., -3.]) + sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)]) + t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T + student = student_t.StudentT(df, loc=mu, scale=sigma) + log_pdf = student.log_prob(t) + log_pdf_values = self.evaluate(log_pdf) + self.assertEqual(log_pdf.get_shape(), (6, 2)) + pdf = student.prob(t) + pdf_values = self.evaluate(pdf) + self.assertEqual(pdf.get_shape(), (6, 2)) + + if not stats: + return + expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v) + expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v) + self.assertAllClose(expected_log_pdf, log_pdf_values) + self.assertAllClose(np.log(expected_pdf), log_pdf_values) + self.assertAllClose(expected_pdf, pdf_values) + self.assertAllClose(np.exp(expected_log_pdf), pdf_values) def testStudentCDFAndLogCDF(self): - with self.test_session(): - batch_size = 6 - df = constant_op.constant([3.] * batch_size) - mu = constant_op.constant([7.] * batch_size) - sigma = constant_op.constant([-8.] * batch_size) - df_v = 3. - mu_v = 7. - sigma_v = 8. 
- t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32) - student = student_t.StudentT(df, loc=mu, scale=sigma) - - log_cdf = student.log_cdf(t) - self.assertEquals(log_cdf.get_shape(), (6,)) - log_cdf_values = self.evaluate(log_cdf) - cdf = student.cdf(t) - self.assertEquals(cdf.get_shape(), (6,)) - cdf_values = self.evaluate(cdf) - - if not stats: - return - expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v) - expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v) - self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5) - self.assertAllClose( - np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5) - self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5) - self.assertAllClose( - np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5) + batch_size = 6 + df = constant_op.constant([3.] * batch_size) + mu = constant_op.constant([7.] * batch_size) + sigma = constant_op.constant([-8.] * batch_size) + df_v = 3. + mu_v = 7. + sigma_v = 8. 
+ t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32) + student = student_t.StudentT(df, loc=mu, scale=sigma) + + log_cdf = student.log_cdf(t) + self.assertEquals(log_cdf.get_shape(), (6,)) + log_cdf_values = self.evaluate(log_cdf) + cdf = student.cdf(t) + self.assertEquals(cdf.get_shape(), (6,)) + cdf_values = self.evaluate(cdf) + + if not stats: + return + expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v) + expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v) + self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5) + self.assertAllClose( + np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5) + self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5) + self.assertAllClose( + np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5) def testStudentEntropy(self): df_v = np.array([[2., 3., 7.]]) # 1x3 mu_v = np.array([[1., -1, 0]]) # 1x3 sigma_v = np.array([[1., -2., 3.]]).T # transposed => 3x1 - with self.test_session(): - student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v) - ent = student.entropy() - ent_values = self.evaluate(ent) + student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v) + ent = student.entropy() + ent_values = self.evaluate(ent) # Help scipy broadcast to 3x3 ones = np.array([[1, 1, 1]]) @@ -160,90 +156,81 @@ class StudentTTest(test.TestCase): self.assertAllClose(expected_entropy, ent_values) def testStudentSample(self): - with self.test_session(): - df = constant_op.constant(4.) - mu = constant_op.constant(3.) - sigma = constant_op.constant(-math.sqrt(10.)) - df_v = 4. - mu_v = 3. - sigma_v = np.sqrt(10.) 
- n = constant_op.constant(200000) - student = student_t.StudentT(df=df, loc=mu, scale=sigma) - samples = student.sample(n, seed=123456) - sample_values = self.evaluate(samples) - n_val = 200000 - self.assertEqual(sample_values.shape, (n_val,)) - self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0) - self.assertAllClose( - sample_values.var(), - sigma_v**2 * df_v / (df_v - 2), - rtol=0.1, - atol=0) - self._checkKLApprox(df_v, mu_v, sigma_v, sample_values) + df = constant_op.constant(4.) + mu = constant_op.constant(3.) + sigma = constant_op.constant(-math.sqrt(10.)) + df_v = 4. + mu_v = 3. + sigma_v = np.sqrt(10.) + n = constant_op.constant(200000) + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + samples = student.sample(n, seed=123456) + sample_values = self.evaluate(samples) + n_val = 200000 + self.assertEqual(sample_values.shape, (n_val,)) + self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0) + self.assertAllClose( + sample_values.var(), sigma_v**2 * df_v / (df_v - 2), rtol=0.1, atol=0) + self._checkKLApprox(df_v, mu_v, sigma_v, sample_values) # Test that sampling with the same seed twice gives the same results. def testStudentSampleMultipleTimes(self): - with self.test_session(): - df = constant_op.constant(4.) - mu = constant_op.constant(3.) - sigma = constant_op.constant(math.sqrt(10.)) - n = constant_op.constant(100) + df = constant_op.constant(4.) + mu = constant_op.constant(3.) 
+ sigma = constant_op.constant(math.sqrt(10.)) + n = constant_op.constant(100) - random_seed.set_random_seed(654321) - student = student_t.StudentT( - df=df, loc=mu, scale=sigma, name="student_t1") - samples1 = self.evaluate(student.sample(n, seed=123456)) + random_seed.set_random_seed(654321) + student = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t1") + samples1 = self.evaluate(student.sample(n, seed=123456)) - random_seed.set_random_seed(654321) - student2 = student_t.StudentT( - df=df, loc=mu, scale=sigma, name="student_t2") - samples2 = self.evaluate(student2.sample(n, seed=123456)) + random_seed.set_random_seed(654321) + student2 = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t2") + samples2 = self.evaluate(student2.sample(n, seed=123456)) - self.assertAllClose(samples1, samples2) + self.assertAllClose(samples1, samples2) def testStudentSampleSmallDfNoNan(self): - with self.test_session(): - df_v = [1e-1, 1e-5, 1e-10, 1e-20] - df = constant_op.constant(df_v) - n = constant_op.constant(200000) - student = student_t.StudentT(df=df, loc=1., scale=1.) - samples = student.sample(n, seed=123456) - sample_values = self.evaluate(samples) - n_val = 200000 - self.assertEqual(sample_values.shape, (n_val, 4)) - self.assertTrue(np.all(np.logical_not(np.isnan(sample_values)))) + df_v = [1e-1, 1e-5, 1e-10, 1e-20] + df = constant_op.constant(df_v) + n = constant_op.constant(200000) + student = student_t.StudentT(df=df, loc=1., scale=1.) 
+ samples = student.sample(n, seed=123456) + sample_values = self.evaluate(samples) + n_val = 200000 + self.assertEqual(sample_values.shape, (n_val, 4)) + self.assertTrue(np.all(np.logical_not(np.isnan(sample_values)))) def testStudentSampleMultiDimensional(self): - with self.test_session(): - batch_size = 7 - df = constant_op.constant([[5., 7.]] * batch_size) - mu = constant_op.constant([[3., -3.]] * batch_size) - sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] * - batch_size) - df_v = [5., 7.] - mu_v = [3., -3.] - sigma_v = [np.sqrt(10.), np.sqrt(15.)] - n = constant_op.constant(200000) - student = student_t.StudentT(df=df, loc=mu, scale=sigma) - samples = student.sample(n, seed=123456) - sample_values = self.evaluate(samples) - self.assertEqual(samples.get_shape(), (200000, batch_size, 2)) - self.assertAllClose( - sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0) - self.assertAllClose( - sample_values[:, 0, 0].var(), - sigma_v[0]**2 * df_v[0] / (df_v[0] - 2), - rtol=0.2, - atol=0) - self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0]) - self.assertAllClose( - sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0) - self.assertAllClose( - sample_values[:, 0, 1].var(), - sigma_v[1]**2 * df_v[1] / (df_v[1] - 2), - rtol=0.2, - atol=0) - self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1]) + batch_size = 7 + df = constant_op.constant([[5., 7.]] * batch_size) + mu = constant_op.constant([[3., -3.]] * batch_size) + sigma = constant_op.constant( + [[math.sqrt(10.), math.sqrt(15.)]] * batch_size) + df_v = [5., 7.] + mu_v = [3., -3.] 
+ sigma_v = [np.sqrt(10.), np.sqrt(15.)] + n = constant_op.constant(200000) + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + samples = student.sample(n, seed=123456) + sample_values = self.evaluate(samples) + self.assertEqual(samples.get_shape(), (200000, batch_size, 2)) + self.assertAllClose( + sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0) + self.assertAllClose( + sample_values[:, 0, 0].var(), + sigma_v[0]**2 * df_v[0] / (df_v[0] - 2), + rtol=0.2, + atol=0) + self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0]) + self.assertAllClose( + sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0) + self.assertAllClose( + sample_values[:, 0, 1].var(), + sigma_v[1]**2 * df_v[1] / (df_v[1] - 2), + rtol=0.2, + atol=0) + self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1]) def _checkKLApprox(self, df, mu, sigma, samples): n = samples.size @@ -325,114 +312,102 @@ class StudentTTest(test.TestCase): _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]])) def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): - with self.test_session(): - mu = [1., 3.3, 4.4] - student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.]) - mean = self.evaluate(student.mean()) - self.assertAllClose([1., 3.3, 4.4], mean) + mu = [1., 3.3, 4.4] + student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.]) + mean = self.evaluate(student.mean()) + self.assertAllClose([1., 3.3, 4.4], mean) def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self): - with self.test_session(): - mu = [1., 3.3, 4.4] - student = student_t.StudentT( - df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], - allow_nan_stats=False) - with self.assertRaisesOpError("x < y"): - self.evaluate(student.mean()) + mu = [1., 3.3, 4.4] + student = student_t.StudentT( + df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False) + with self.assertRaisesOpError("x < y"): + 
self.evaluate(student.mean()) def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self): - with self.test_session(): - mu = [-2, 0., 1., 3.3, 4.4] - sigma = [5., 4., 3., 2., 1.] - student = student_t.StudentT( - df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, - allow_nan_stats=True) - mean = self.evaluate(student.mean()) - self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean) + mu = [-2, 0., 1., 3.3, 4.4] + sigma = [5., 4., 3., 2., 1.] + student = student_t.StudentT( + df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True) + mean = self.evaluate(student.mean()) + self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean) def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self): - with self.test_session(): - # df = 0.5 ==> undefined mean ==> undefined variance. - # df = 1.5 ==> infinite variance. - df = [0.5, 1.5, 3., 5., 7.] - mu = [-2, 0., 1., 3.3, 4.4] - sigma = [5., 4., 3., 2., 1.] - student = student_t.StudentT( - df=df, loc=mu, scale=sigma, allow_nan_stats=True) - var = self.evaluate(student.variance()) - ## scipy uses inf for variance when the mean is undefined. When mean is - # undefined we say variance is undefined as well. So test the first - # member of var, making sure it is NaN, then replace with inf and compare - # to scipy. - self.assertTrue(np.isnan(var[0])) - var[0] = np.inf - - if not stats: - return - expected_var = [ - stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) - ] - self.assertAllClose(expected_var, var) + # df = 0.5 ==> undefined mean ==> undefined variance. + # df = 1.5 ==> infinite variance. + df = [0.5, 1.5, 3., 5., 7.] + mu = [-2, 0., 1., 3.3, 4.4] + sigma = [5., 4., 3., 2., 1.] + student = student_t.StudentT( + df=df, loc=mu, scale=sigma, allow_nan_stats=True) + var = self.evaluate(student.variance()) + ## scipy uses inf for variance when the mean is undefined. When mean is + # undefined we say variance is undefined as well. 
So test the first + # member of var, making sure it is NaN, then replace with inf and compare + # to scipy. + self.assertTrue(np.isnan(var[0])) + var[0] = np.inf + + if not stats: + return + expected_var = [ + stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) + ] + self.assertAllClose(expected_var, var) def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers( self): - with self.test_session(): - # df = 1.5 ==> infinite variance. - df = [1.5, 3., 5., 7.] - mu = [0., 1., 3.3, 4.4] - sigma = [4., 3., 2., 1.] - student = student_t.StudentT(df=df, loc=mu, scale=sigma) - var = self.evaluate(student.variance()) + # df = 1.5 ==> infinite variance. + df = [1.5, 3., 5., 7.] + mu = [0., 1., 3.3, 4.4] + sigma = [4., 3., 2., 1.] + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + var = self.evaluate(student.variance()) - if not stats: - return - expected_var = [ - stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) - ] - self.assertAllClose(expected_var, var) + if not stats: + return + expected_var = [ + stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) + ] + self.assertAllClose(expected_var, var) def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): - with self.test_session(): - # df <= 1 ==> variance not defined - student = student_t.StudentT( - df=1., loc=0., scale=1., allow_nan_stats=False) - with self.assertRaisesOpError("x < y"): - self.evaluate(student.variance()) + # df <= 1 ==> variance not defined + student = student_t.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False) + with self.assertRaisesOpError("x < y"): + self.evaluate(student.variance()) - with self.test_session(): - # df <= 1 ==> variance not defined - student = student_t.StudentT( - df=0.5, loc=0., scale=1., allow_nan_stats=False) - with self.assertRaisesOpError("x < y"): - self.evaluate(student.variance()) + # df <= 1 ==> variance not defined + student = student_t.StudentT( + df=0.5, loc=0., scale=1., 
allow_nan_stats=False) + with self.assertRaisesOpError("x < y"): + self.evaluate(student.variance()) def testStd(self): - with self.test_session(): - # Defined for all batch members. - df = [3.5, 5., 3., 5., 7.] - mu = [-2.2] - sigma = [5., 4., 3., 2., 1.] - student = student_t.StudentT(df=df, loc=mu, scale=sigma) - # Test broadcast of mu across shape of df/sigma - stddev = self.evaluate(student.stddev()) - mu *= len(df) + # Defined for all batch members. + df = [3.5, 5., 3., 5., 7.] + mu = [-2.2] + sigma = [5., 4., 3., 2., 1.] + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + # Test broadcast of mu across shape of df/sigma + stddev = self.evaluate(student.stddev()) + mu *= len(df) - if not stats: - return - expected_stddev = [ - stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) - ] - self.assertAllClose(expected_stddev, stddev) + if not stats: + return + expected_stddev = [ + stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) + ] + self.assertAllClose(expected_stddev, stddev) def testMode(self): - with self.test_session(): - df = [0.5, 1., 3] - mu = [-1, 0., 1] - sigma = [5., 4., 3.] - student = student_t.StudentT(df=df, loc=mu, scale=sigma) - # Test broadcast of mu across shape of df/sigma - mode = self.evaluate(student.mode()) - self.assertAllClose([-1., 0, 1], mode) + df = [0.5, 1., 3] + mu = [-1, 0., 1] + sigma = [5., 4., 3.] + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + # Test broadcast of mu across shape of df/sigma + mode = self.evaluate(student.mode()) + self.assertAllClose([-1., 0, 1], mode) def testPdfOfSample(self): student = student_t.StudentT(df=3., loc=np.pi, scale=1.) 
@@ -510,25 +485,23 @@ class StudentTTest(test.TestCase): self.assertNear(1., total, err=err) def testNegativeDofFails(self): - with self.test_session(): - with self.assertRaisesOpError(r"Condition x > 0 did not hold"): - student = student_t.StudentT( - df=[2, -5.], loc=0., scale=1., validate_args=True, name="S") - self.evaluate(student.mean()) + with self.assertRaisesOpError(r"Condition x > 0 did not hold"): + student = student_t.StudentT( + df=[2, -5.], loc=0., scale=1., validate_args=True, name="S") + self.evaluate(student.mean()) def testStudentTWithAbsDfSoftplusScale(self): - with self.test_session(): - df = constant_op.constant([-3.2, -4.6]) - mu = constant_op.constant([-4.2, 3.4]) - sigma = constant_op.constant([-6.4, -8.8]) - student = student_t.StudentTWithAbsDfSoftplusScale( - df=df, loc=mu, scale=sigma) - self.assertAllClose( - math_ops.floor(self.evaluate(math_ops.abs(df))), - self.evaluate(student.df)) - self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc)) - self.assertAllClose( - self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale)) + df = constant_op.constant([-3.2, -4.6]) + mu = constant_op.constant([-4.2, 3.4]) + sigma = constant_op.constant([-6.4, -8.8]) + student = student_t.StudentTWithAbsDfSoftplusScale( + df=df, loc=mu, scale=sigma) + self.assertAllClose( + math_ops.floor(self.evaluate(math_ops.abs(df))), + self.evaluate(student.df)) + self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc)) + self.assertAllClose( + self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale)) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py index bc9c267b9a..9cdcd369c1 100644 --- a/tensorflow/python/kernel_tests/distributions/uniform_test.py +++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py @@ -50,255 +50,239 @@ class UniformTest(test.TestCase): @test_util.run_in_graph_and_eager_modes 
def testUniformRange(self): - with self.test_session(): - a = 3.0 - b = 10.0 - uniform = uniform_lib.Uniform(low=a, high=b) - self.assertAllClose(a, self.evaluate(uniform.low)) - self.assertAllClose(b, self.evaluate(uniform.high)) - self.assertAllClose(b - a, self.evaluate(uniform.range())) + a = 3.0 + b = 10.0 + uniform = uniform_lib.Uniform(low=a, high=b) + self.assertAllClose(a, self.evaluate(uniform.low)) + self.assertAllClose(b, self.evaluate(uniform.high)) + self.assertAllClose(b - a, self.evaluate(uniform.range())) @test_util.run_in_graph_and_eager_modes def testUniformPDF(self): - with self.test_session(): - a = constant_op.constant([-3.0] * 5 + [15.0]) - b = constant_op.constant([11.0] * 5 + [20.0]) - uniform = uniform_lib.Uniform(low=a, high=b) + a = constant_op.constant([-3.0] * 5 + [15.0]) + b = constant_op.constant([11.0] * 5 + [20.0]) + uniform = uniform_lib.Uniform(low=a, high=b) - a_v = -3.0 - b_v = 11.0 - x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32) + a_v = -3.0 + b_v = 11.0 + x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32) - def _expected_pdf(): - pdf = np.zeros_like(x) + 1.0 / (b_v - a_v) - pdf[x > b_v] = 0.0 - pdf[x < a_v] = 0.0 - pdf[5] = 1.0 / (20.0 - 15.0) - return pdf + def _expected_pdf(): + pdf = np.zeros_like(x) + 1.0 / (b_v - a_v) + pdf[x > b_v] = 0.0 + pdf[x < a_v] = 0.0 + pdf[5] = 1.0 / (20.0 - 15.0) + return pdf - expected_pdf = _expected_pdf() + expected_pdf = _expected_pdf() - pdf = uniform.prob(x) - self.assertAllClose(expected_pdf, self.evaluate(pdf)) + pdf = uniform.prob(x) + self.assertAllClose(expected_pdf, self.evaluate(pdf)) - log_pdf = uniform.log_prob(x) - self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf)) + log_pdf = uniform.log_prob(x) + self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf)) @test_util.run_in_graph_and_eager_modes def testUniformShape(self): - with self.test_session(): - a = constant_op.constant([-3.0] * 5) - b = 
constant_op.constant(11.0) - uniform = uniform_lib.Uniform(low=a, high=b) + a = constant_op.constant([-3.0] * 5) + b = constant_op.constant(11.0) + uniform = uniform_lib.Uniform(low=a, high=b) - self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,)) - self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), []) - self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([])) + self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,)) + self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), []) + self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([])) @test_util.run_in_graph_and_eager_modes def testUniformPDFWithScalarEndpoint(self): - with self.test_session(): - a = constant_op.constant([0.0, 5.0]) - b = constant_op.constant(10.0) - uniform = uniform_lib.Uniform(low=a, high=b) + a = constant_op.constant([0.0, 5.0]) + b = constant_op.constant(10.0) + uniform = uniform_lib.Uniform(low=a, high=b) - x = np.array([0.0, 8.0], dtype=np.float32) - expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)]) + x = np.array([0.0, 8.0], dtype=np.float32) + expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)]) - pdf = uniform.prob(x) - self.assertAllClose(expected_pdf, self.evaluate(pdf)) + pdf = uniform.prob(x) + self.assertAllClose(expected_pdf, self.evaluate(pdf)) @test_util.run_in_graph_and_eager_modes def testUniformCDF(self): - with self.test_session(): - batch_size = 6 - a = constant_op.constant([1.0] * batch_size) - b = constant_op.constant([11.0] * batch_size) - a_v = 1.0 - b_v = 11.0 - x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32) + batch_size = 6 + a = constant_op.constant([1.0] * batch_size) + b = constant_op.constant([11.0] * batch_size) + a_v = 1.0 + b_v = 11.0 + x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32) 
- uniform = uniform_lib.Uniform(low=a, high=b) + uniform = uniform_lib.Uniform(low=a, high=b) - def _expected_cdf(): - cdf = (x - a_v) / (b_v - a_v) - cdf[x >= b_v] = 1 - cdf[x < a_v] = 0 - return cdf + def _expected_cdf(): + cdf = (x - a_v) / (b_v - a_v) + cdf[x >= b_v] = 1 + cdf[x < a_v] = 0 + return cdf - cdf = uniform.cdf(x) - self.assertAllClose(_expected_cdf(), self.evaluate(cdf)) + cdf = uniform.cdf(x) + self.assertAllClose(_expected_cdf(), self.evaluate(cdf)) - log_cdf = uniform.log_cdf(x) - self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf)) + log_cdf = uniform.log_cdf(x) + self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf)) @test_util.run_in_graph_and_eager_modes def testUniformEntropy(self): - with self.test_session(): - a_v = np.array([1.0, 1.0, 1.0]) - b_v = np.array([[1.5, 2.0, 3.0]]) - uniform = uniform_lib.Uniform(low=a_v, high=b_v) + a_v = np.array([1.0, 1.0, 1.0]) + b_v = np.array([[1.5, 2.0, 3.0]]) + uniform = uniform_lib.Uniform(low=a_v, high=b_v) - expected_entropy = np.log(b_v - a_v) - self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy())) + expected_entropy = np.log(b_v - a_v) + self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy())) @test_util.run_in_graph_and_eager_modes def testUniformAssertMaxGtMin(self): - with self.test_session(): - a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32) - b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32) + a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32) + b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32) - with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - "x < y"): - uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True) - self.evaluate(uniform.low) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "x < y"): + uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True) + self.evaluate(uniform.low) @test_util.run_in_graph_and_eager_modes def testUniformSample(self): - with 
self.test_session(): - a = constant_op.constant([3.0, 4.0]) - b = constant_op.constant(13.0) - a1_v = 3.0 - a2_v = 4.0 - b_v = 13.0 - n = constant_op.constant(100000) - uniform = uniform_lib.Uniform(low=a, high=b) - - samples = uniform.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(sample_values.shape, (100000, 2)) - self.assertAllClose( - sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.) - self.assertAllClose( - sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.) - self.assertFalse( - np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v)) - self.assertFalse( - np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v)) + a = constant_op.constant([3.0, 4.0]) + b = constant_op.constant(13.0) + a1_v = 3.0 + a2_v = 4.0 + b_v = 13.0 + n = constant_op.constant(100000) + uniform = uniform_lib.Uniform(low=a, high=b) + + samples = uniform.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(sample_values.shape, (100000, 2)) + self.assertAllClose( + sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.) + self.assertAllClose( + sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.) + self.assertFalse( + np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v)) + self.assertFalse( + np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v)) @test_util.run_in_graph_and_eager_modes def _testUniformSampleMultiDimensional(self): # DISABLED: Please enable this test once b/issues/30149644 is resolved. 
- with self.test_session(): - batch_size = 2 - a_v = [3.0, 22.0] - b_v = [13.0, 35.0] - a = constant_op.constant([a_v] * batch_size) - b = constant_op.constant([b_v] * batch_size) - - uniform = uniform_lib.Uniform(low=a, high=b) - - n_v = 100000 - n = constant_op.constant(n_v) - samples = uniform.sample(n) - self.assertEqual(samples.get_shape(), (n_v, batch_size, 2)) - - sample_values = self.evaluate(samples) - - self.assertFalse( - np.any(sample_values[:, 0, 0] < a_v[0]) or - np.any(sample_values[:, 0, 0] >= b_v[0])) - self.assertFalse( - np.any(sample_values[:, 0, 1] < a_v[1]) or - np.any(sample_values[:, 0, 1] >= b_v[1])) - - self.assertAllClose( - sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2) - self.assertAllClose( - sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2) + batch_size = 2 + a_v = [3.0, 22.0] + b_v = [13.0, 35.0] + a = constant_op.constant([a_v] * batch_size) + b = constant_op.constant([b_v] * batch_size) + + uniform = uniform_lib.Uniform(low=a, high=b) + + n_v = 100000 + n = constant_op.constant(n_v) + samples = uniform.sample(n) + self.assertEqual(samples.get_shape(), (n_v, batch_size, 2)) + + sample_values = self.evaluate(samples) + + self.assertFalse( + np.any(sample_values[:, 0, 0] < a_v[0]) or + np.any(sample_values[:, 0, 0] >= b_v[0])) + self.assertFalse( + np.any(sample_values[:, 0, 1] < a_v[1]) or + np.any(sample_values[:, 0, 1] >= b_v[1])) + + self.assertAllClose( + sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2) + self.assertAllClose( + sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2) @test_util.run_in_graph_and_eager_modes def testUniformMean(self): - with self.test_session(): - a = 10.0 - b = 100.0 - uniform = uniform_lib.Uniform(low=a, high=b) - if not stats: - return - s_uniform = stats.uniform(loc=a, scale=b - a) - self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean()) + a = 10.0 + b = 100.0 + uniform = uniform_lib.Uniform(low=a, high=b) + if not stats: + 
return + s_uniform = stats.uniform(loc=a, scale=b - a) + self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean()) @test_util.run_in_graph_and_eager_modes def testUniformVariance(self): - with self.test_session(): - a = 10.0 - b = 100.0 - uniform = uniform_lib.Uniform(low=a, high=b) - if not stats: - return - s_uniform = stats.uniform(loc=a, scale=b - a) - self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var()) + a = 10.0 + b = 100.0 + uniform = uniform_lib.Uniform(low=a, high=b) + if not stats: + return + s_uniform = stats.uniform(loc=a, scale=b - a) + self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var()) @test_util.run_in_graph_and_eager_modes def testUniformStd(self): - with self.test_session(): - a = 10.0 - b = 100.0 - uniform = uniform_lib.Uniform(low=a, high=b) - if not stats: - return - s_uniform = stats.uniform(loc=a, scale=b - a) - self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std()) + a = 10.0 + b = 100.0 + uniform = uniform_lib.Uniform(low=a, high=b) + if not stats: + return + s_uniform = stats.uniform(loc=a, scale=b - a) + self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std()) @test_util.run_in_graph_and_eager_modes def testUniformNans(self): - with self.test_session(): - a = 10.0 - b = [11.0, 100.0] - uniform = uniform_lib.Uniform(low=a, high=b) + a = 10.0 + b = [11.0, 100.0] + uniform = uniform_lib.Uniform(low=a, high=b) - no_nans = constant_op.constant(1.0) - nans = constant_op.constant(0.0) / constant_op.constant(0.0) - self.assertTrue(self.evaluate(math_ops.is_nan(nans))) - with_nans = array_ops.stack([no_nans, nans]) + no_nans = constant_op.constant(1.0) + nans = constant_op.constant(0.0) / constant_op.constant(0.0) + self.assertTrue(self.evaluate(math_ops.is_nan(nans))) + with_nans = array_ops.stack([no_nans, nans]) - pdf = uniform.prob(with_nans) + pdf = uniform.prob(with_nans) - is_nan = self.evaluate(math_ops.is_nan(pdf)) - self.assertFalse(is_nan[0]) - 
self.assertTrue(is_nan[1]) + is_nan = self.evaluate(math_ops.is_nan(pdf)) + self.assertFalse(is_nan[0]) + self.assertTrue(is_nan[1]) @test_util.run_in_graph_and_eager_modes def testUniformSamplePdf(self): - with self.test_session(): - a = 10.0 - b = [11.0, 100.0] - uniform = uniform_lib.Uniform(a, b) - self.assertTrue( - self.evaluate( - math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0))) + a = 10.0 + b = [11.0, 100.0] + uniform = uniform_lib.Uniform(a, b) + self.assertTrue( + self.evaluate( + math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0))) @test_util.run_in_graph_and_eager_modes def testUniformBroadcasting(self): - with self.test_session(): - a = 10.0 - b = [11.0, 20.0] - uniform = uniform_lib.Uniform(a, b) + a = 10.0 + b = [11.0, 20.0] + uniform = uniform_lib.Uniform(a, b) - pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]]) - expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]]) - self.assertAllClose(expected_pdf, self.evaluate(pdf)) + pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]]) + expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]]) + self.assertAllClose(expected_pdf, self.evaluate(pdf)) @test_util.run_in_graph_and_eager_modes def testUniformSampleWithShape(self): - with self.test_session(): - a = 10.0 - b = [11.0, 20.0] - uniform = uniform_lib.Uniform(a, b) - - pdf = uniform.prob(uniform.sample((2, 3))) - # pylint: disable=bad-continuation - expected_pdf = [ - [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]], - [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]], - ] - # pylint: enable=bad-continuation - self.assertAllClose(expected_pdf, self.evaluate(pdf)) - - pdf = uniform.prob(uniform.sample()) - expected_pdf = [1.0, 0.1] - self.assertAllClose(expected_pdf, self.evaluate(pdf)) + a = 10.0 + b = [11.0, 20.0] + uniform = uniform_lib.Uniform(a, b) + + pdf = uniform.prob(uniform.sample((2, 3))) + # pylint: disable=bad-continuation + expected_pdf = [ + [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]], + [[1.0, 0.1], [1.0, 0.1], [1.0, 
0.1]], + ] + # pylint: enable=bad-continuation + self.assertAllClose(expected_pdf, self.evaluate(pdf)) + + pdf = uniform.prob(uniform.sample()) + expected_pdf = [1.0, 0.1] + self.assertAllClose(expected_pdf, self.evaluate(pdf)) def testFullyReparameterized(self): a = constant_op.constant(0.1) diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py index 61faa8466e..27d652c2c6 100644 --- a/tensorflow/python/kernel_tests/distributions/util_test.py +++ b/tensorflow/python/kernel_tests/distributions/util_test.py @@ -69,7 +69,7 @@ class AssertCloseTest(test.TestCase): w = array_ops.placeholder(dtypes.float32) feed_dict = {x: [1., 5, 10, 15, 20], y: [1.1, 5, 10, 15, 20], z: [1.0001, 5, 10, 15, 20], w: [1e-8, 5, 10, 15, 20]} - with self.test_session(): + with self.cached_session(): with ops.control_dependencies([du.assert_integer_form(x)]): array_ops.identity(x).eval(feed_dict=feed_dict) @@ -122,58 +122,52 @@ class GetLogitsAndProbsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testImproperArguments(self): - with self.test_session(): - with self.assertRaises(ValueError): - du.get_logits_and_probs(logits=None, probs=None) + with self.assertRaises(ValueError): + du.get_logits_and_probs(logits=None, probs=None) - with self.assertRaises(ValueError): - du.get_logits_and_probs(logits=[0.1], probs=[0.1]) + with self.assertRaises(ValueError): + du.get_logits_and_probs(logits=[0.1], probs=[0.1]) @test_util.run_in_graph_and_eager_modes def testLogits(self): p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32) logits = _logit(p) - with self.test_session(): - new_logits, new_p = du.get_logits_and_probs( - logits=logits, validate_args=True) + new_logits, new_p = du.get_logits_and_probs( + logits=logits, validate_args=True) - self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.) - self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.) 
+ self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.) + self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.) @test_util.run_in_graph_and_eager_modes def testLogitsMultidimensional(self): p = np.array([0.2, 0.3, 0.5], dtype=np.float32) logits = np.log(p) - with self.test_session(): - new_logits, new_p = du.get_logits_and_probs( - logits=logits, multidimensional=True, validate_args=True) + new_logits, new_p = du.get_logits_and_probs( + logits=logits, multidimensional=True, validate_args=True) - self.assertAllClose(self.evaluate(new_p), p) - self.assertAllClose(self.evaluate(new_logits), logits) + self.assertAllClose(self.evaluate(new_p), p) + self.assertAllClose(self.evaluate(new_logits), logits) @test_util.run_in_graph_and_eager_modes def testProbability(self): p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32) - with self.test_session(): - new_logits, new_p = du.get_logits_and_probs( - probs=p, validate_args=True) + new_logits, new_p = du.get_logits_and_probs(probs=p, validate_args=True) - self.assertAllClose(_logit(p), self.evaluate(new_logits)) - self.assertAllClose(p, self.evaluate(new_p)) + self.assertAllClose(_logit(p), self.evaluate(new_logits)) + self.assertAllClose(p, self.evaluate(new_p)) @test_util.run_in_graph_and_eager_modes def testProbabilityMultidimensional(self): p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32) - with self.test_session(): - new_logits, new_p = du.get_logits_and_probs( - probs=p, multidimensional=True, validate_args=True) + new_logits, new_p = du.get_logits_and_probs( + probs=p, multidimensional=True, validate_args=True) - self.assertAllClose(np.log(p), self.evaluate(new_logits)) - self.assertAllClose(p, self.evaluate(new_p)) + self.assertAllClose(np.log(p), self.evaluate(new_logits)) + self.assertAllClose(p, self.evaluate(new_p)) @test_util.run_in_graph_and_eager_modes def testProbabilityValidateArgs(self): @@ -183,29 +177,23 @@ class GetLogitsAndProbsTest(test.TestCase): 
# Component greater than 1. p3 = [2, 0.2, 0.5, 0.3, .2] - with self.test_session(): - _, prob = du.get_logits_and_probs( - probs=p, validate_args=True) - self.evaluate(prob) - - with self.assertRaisesOpError("Condition x >= 0"): - _, prob = du.get_logits_and_probs( - probs=p2, validate_args=True) - self.evaluate(prob) + _, prob = du.get_logits_and_probs(probs=p, validate_args=True) + self.evaluate(prob) - _, prob = du.get_logits_and_probs( - probs=p2, validate_args=False) + with self.assertRaisesOpError("Condition x >= 0"): + _, prob = du.get_logits_and_probs(probs=p2, validate_args=True) self.evaluate(prob) - with self.assertRaisesOpError("probs has components greater than 1"): - _, prob = du.get_logits_and_probs( - probs=p3, validate_args=True) - self.evaluate(prob) + _, prob = du.get_logits_and_probs(probs=p2, validate_args=False) + self.evaluate(prob) - _, prob = du.get_logits_and_probs( - probs=p3, validate_args=False) + with self.assertRaisesOpError("probs has components greater than 1"): + _, prob = du.get_logits_and_probs(probs=p3, validate_args=True) self.evaluate(prob) + _, prob = du.get_logits_and_probs(probs=p3, validate_args=False) + self.evaluate(prob) + @test_util.run_in_graph_and_eager_modes def testProbabilityValidateArgsMultidimensional(self): p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32) @@ -216,41 +204,39 @@ class GetLogitsAndProbsTest(test.TestCase): # Does not sum to 1. 
p4 = np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32) - with self.test_session(): - _, prob = du.get_logits_and_probs( - probs=p, multidimensional=True) - self.evaluate(prob) - - with self.assertRaisesOpError("Condition x >= 0"): - _, prob = du.get_logits_and_probs( - probs=p2, multidimensional=True, validate_args=True) - self.evaluate(prob) + _, prob = du.get_logits_and_probs(probs=p, multidimensional=True) + self.evaluate(prob) + with self.assertRaisesOpError("Condition x >= 0"): _, prob = du.get_logits_and_probs( - probs=p2, multidimensional=True, validate_args=False) + probs=p2, multidimensional=True, validate_args=True) self.evaluate(prob) - with self.assertRaisesOpError( - "(probs has components greater than 1|probs does not sum to 1)"): - _, prob = du.get_logits_and_probs( - probs=p3, multidimensional=True, validate_args=True) - self.evaluate(prob) + _, prob = du.get_logits_and_probs( + probs=p2, multidimensional=True, validate_args=False) + self.evaluate(prob) + with self.assertRaisesOpError( + "(probs has components greater than 1|probs does not sum to 1)"): _, prob = du.get_logits_and_probs( - probs=p3, multidimensional=True, validate_args=False) + probs=p3, multidimensional=True, validate_args=True) self.evaluate(prob) - with self.assertRaisesOpError("probs does not sum to 1"): - _, prob = du.get_logits_and_probs( - probs=p4, multidimensional=True, validate_args=True) - self.evaluate(prob) + _, prob = du.get_logits_and_probs( + probs=p3, multidimensional=True, validate_args=False) + self.evaluate(prob) + with self.assertRaisesOpError("probs does not sum to 1"): _, prob = du.get_logits_and_probs( - probs=p4, multidimensional=True, validate_args=False) + probs=p4, multidimensional=True, validate_args=True) self.evaluate(prob) + _, prob = du.get_logits_and_probs( + probs=p4, multidimensional=True, validate_args=False) + self.evaluate(prob) + def testProbsMultidimShape(self): - with self.test_session(): + with self.cached_session(): with 
self.assertRaises(ValueError): p = array_ops.ones([int(2**11+1)], dtype=np.float16) du.get_logits_and_probs( @@ -264,7 +250,7 @@ class GetLogitsAndProbsTest(test.TestCase): prob.eval(feed_dict={p: np.ones([int(2**11+1)])}) def testLogitsMultidimShape(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): l = array_ops.ones([int(2**11+1)], dtype=np.float16) du.get_logits_and_probs( @@ -281,7 +267,7 @@ class GetLogitsAndProbsTest(test.TestCase): class EmbedCheckCategoricalEventShapeTest(test.TestCase): def testTooSmall(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): param = array_ops.ones([1], dtype=np.float16) checked_param = du.embed_check_categorical_event_shape( @@ -295,7 +281,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase): checked_param.eval(feed_dict={param: np.ones([1])}) def testTooLarge(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): param = array_ops.ones([int(2**11+1)], dtype=dtypes.float16) checked_param = du.embed_check_categorical_event_shape( @@ -310,18 +296,17 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testUnsupportedDtype(self): - with self.test_session(): - param = ops.convert_to_tensor( - np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype), - dtype=dtypes.qint16) - with self.assertRaises(TypeError): - du.embed_check_categorical_event_shape(param) + param = ops.convert_to_tensor( + np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype), + dtype=dtypes.qint16) + with self.assertRaises(TypeError): + du.embed_check_categorical_event_shape(param) class EmbedCheckIntegerCastingClosedTest(test.TestCase): def testCorrectlyAssertsNonnegative(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Elements must be non-negative"): x = array_ops.placeholder(dtype=dtypes.float16) 
x_checked = du.embed_check_integer_casting_closed( @@ -329,7 +314,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase): x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.float16)}) def testCorrectlyAssersIntegerForm(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Elements must be int16-equivalent."): x = array_ops.placeholder(dtype=dtypes.float16) x_checked = du.embed_check_integer_casting_closed( @@ -337,7 +322,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase): x_checked.eval(feed_dict={x: np.array([1, 1.5], dtype=np.float16)}) def testCorrectlyAssertsLargestPossibleInteger(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Elements cannot exceed 32767."): x = array_ops.placeholder(dtype=dtypes.int32) x_checked = du.embed_check_integer_casting_closed( @@ -345,7 +330,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase): x_checked.eval(feed_dict={x: np.array([1, 2**15], dtype=np.int32)}) def testCorrectlyAssertsSmallestPossibleInteger(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Elements cannot be smaller than 0."): x = array_ops.placeholder(dtype=dtypes.int32) x_checked = du.embed_check_integer_casting_closed( @@ -365,29 +350,27 @@ class LogCombinationsTest(test.TestCase): log_combs = np.log(special.binom(n, k)) - with self.test_session(): - n = np.array(n, dtype=np.float32) - counts = [[1., 1], [2., 3], [4., 8], [11, 4]] - log_binom = du.log_combinations(n, counts) - self.assertEqual([4], log_binom.get_shape()) - self.assertAllClose(log_combs, self.evaluate(log_binom)) + n = np.array(n, dtype=np.float32) + counts = [[1., 1], [2., 3], [4., 8], [11, 4]] + log_binom = du.log_combinations(n, counts) + self.assertEqual([4], log_binom.get_shape()) + self.assertAllClose(log_combs, self.evaluate(log_binom)) def testLogCombinationsShape(self): # Shape [2, 2] n = [[2, 5], [12, 15]] - with 
self.test_session(): - n = np.array(n, dtype=np.float32) - # Shape [2, 2, 4] - counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]] - log_binom = du.log_combinations(n, counts) - self.assertEqual([2, 2], log_binom.get_shape()) + n = np.array(n, dtype=np.float32) + # Shape [2, 2, 4] + counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]] + log_binom = du.log_combinations(n, counts) + self.assertEqual([2, 2], log_binom.get_shape()) class DynamicShapeTest(test.TestCase): def testSameDynamicShape(self): - with self.test_session(): + with self.cached_session(): scalar = constant_op.constant(2.0) scalar1 = array_ops.placeholder(dtype=dtypes.float32) @@ -497,22 +480,21 @@ class RotateTransposeTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testRollStatic(self): - with self.test_session(): - if context.executing_eagerly(): - error_message = r"Attempt to convert a value \(None\)" - else: - error_message = "None values not supported." - with self.assertRaisesRegexp(ValueError, error_message): - du.rotate_transpose(None, 1) - for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))): - for shift in np.arange(-5, 5): - y = du.rotate_transpose(x, shift) - self.assertAllEqual( - self._np_rotate_transpose(x, shift), self.evaluate(y)) - self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list()) + if context.executing_eagerly(): + error_message = r"Attempt to convert a value \(None\)" + else: + error_message = "None values not supported." 
+ with self.assertRaisesRegexp(ValueError, error_message): + du.rotate_transpose(None, 1) + for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))): + for shift in np.arange(-5, 5): + y = du.rotate_transpose(x, shift) + self.assertAllEqual( + self._np_rotate_transpose(x, shift), self.evaluate(y)) + self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list()) def testRollDynamic(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtypes.float32) shift = array_ops.placeholder(dtypes.int32) for x_value in (np.ones( @@ -530,7 +512,7 @@ class RotateTransposeTest(test.TestCase): class PickVectorTest(test.TestCase): def testCorrectlyPicksVector(self): - with self.test_session(): + with self.cached_session(): x = np.arange(10, 12) y = np.arange(15, 18) self.assertAllEqual( @@ -568,19 +550,19 @@ class PreferStaticRankTest(test.TestCase): def testDynamicRankEndsUpBeingNonEmpty(self): x = array_ops.placeholder(np.float64, shape=None) rank = du.prefer_static_rank(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(2, rank.eval(feed_dict={x: np.zeros((2, 3))})) def testDynamicRankEndsUpBeingEmpty(self): x = array_ops.placeholder(np.int32, shape=None) rank = du.prefer_static_rank(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(1, rank.eval(feed_dict={x: []})) def testDynamicRankEndsUpBeingScalar(self): x = array_ops.placeholder(np.int32, shape=None) rank = du.prefer_static_rank(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(0, rank.eval(feed_dict={x: 1})) @@ -607,19 +589,19 @@ class PreferStaticShapeTest(test.TestCase): def testDynamicShapeEndsUpBeingNonEmpty(self): x = array_ops.placeholder(np.float64, shape=None) shape = du.prefer_static_shape(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual((2, 3), shape.eval(feed_dict={x: np.zeros((2, 3))})) def 
testDynamicShapeEndsUpBeingEmpty(self): x = array_ops.placeholder(np.int32, shape=None) shape = du.prefer_static_shape(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(np.array([0]), shape.eval(feed_dict={x: []})) def testDynamicShapeEndsUpBeingScalar(self): x = array_ops.placeholder(np.int32, shape=None) shape = du.prefer_static_shape(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(np.array([]), shape.eval(feed_dict={x: 1})) @@ -646,20 +628,20 @@ class PreferStaticValueTest(test.TestCase): def testDynamicValueEndsUpBeingNonEmpty(self): x = array_ops.placeholder(np.float64, shape=None) value = du.prefer_static_value(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(np.zeros((2, 3)), value.eval(feed_dict={x: np.zeros((2, 3))})) def testDynamicValueEndsUpBeingEmpty(self): x = array_ops.placeholder(np.int32, shape=None) value = du.prefer_static_value(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(np.array([]), value.eval(feed_dict={x: []})) def testDynamicValueEndsUpBeingScalar(self): x = array_ops.placeholder(np.int32, shape=None) value = du.prefer_static_value(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(np.array(1), value.eval(feed_dict={x: 1})) @@ -691,7 +673,7 @@ class FillTriangularTest(test.TestCase): def _run_test(self, x_, use_deferred_shape=False, **kwargs): x_ = np.asarray(x_) - with self.test_session() as sess: + with self.cached_session() as sess: static_shape = None if use_deferred_shape else x_.shape x_pl = array_ops.placeholder_with_default(x_, shape=static_shape) # Add `zeros_like(x)` such that x's value and gradient are identical. 
We @@ -761,7 +743,7 @@ class FillTriangularInverseTest(FillTriangularTest): def _run_test(self, x_, use_deferred_shape=False, **kwargs): x_ = np.asarray(x_) - with self.test_session() as sess: + with self.cached_session() as sess: static_shape = None if use_deferred_shape else x_.shape x_pl = array_ops.placeholder_with_default(x_, shape=static_shape) zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.) @@ -795,7 +777,7 @@ class ReduceWeightedLogSumExp(test.TestCase): logx_ = np.array([[0., -1, 1000.], [0, 1, -1000.], [-5, 0, 5]]) - with self.test_session() as sess: + with self.cached_session() as sess: logx = constant_op.constant(logx_) expected = math_ops.reduce_logsumexp(logx, axis=-1) grad_expected = gradients_impl.gradients(expected, logx)[0] @@ -818,7 +800,7 @@ class ReduceWeightedLogSumExp(test.TestCase): [1, -2, 1], [1, 0, 1]]) expected, _ = self._reduce_weighted_logsumexp(logx_, w_, axis=-1) - with self.test_session() as sess: + with self.cached_session() as sess: logx = constant_op.constant(logx_) w = constant_op.constant(w_) actual, actual_sgn = du.reduce_weighted_logsumexp( @@ -836,7 +818,7 @@ class ReduceWeightedLogSumExp(test.TestCase): [1, 0, 1]]) expected, _ = self._reduce_weighted_logsumexp( logx_, w_, axis=-1, keep_dims=True) - with self.test_session() as sess: + with self.cached_session() as sess: logx = constant_op.constant(logx_) w = constant_op.constant(w_) actual, actual_sgn = du.reduce_weighted_logsumexp( @@ -848,7 +830,7 @@ class ReduceWeightedLogSumExp(test.TestCase): def testDocString(self): """This test verifies the correctness of the docstring examples.""" - with self.test_session(): + with self.cached_session(): x = constant_op.constant([[0., 0, 0], [0, 0, 0]]) @@ -952,7 +934,7 @@ class SoftplusTest(test.TestCase): use_gpu=True) def testGradient(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant( [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], shape=[2, 5], @@ -968,7 +950,7 @@ 
class SoftplusTest(test.TestCase): self.assertLess(err, 1e-4) def testInverseSoftplusGradientNeverNan(self): - with self.test_session(): + with self.cached_session(): # Note that this range contains both zero and inf. x = constant_op.constant(np.logspace(-8, 6).astype(np.float16)) y = du.softplus_inverse(x) @@ -977,7 +959,7 @@ class SoftplusTest(test.TestCase): self.assertAllEqual(np.zeros_like(grads).astype(np.bool), np.isnan(grads)) def testInverseSoftplusGradientFinite(self): - with self.test_session(): + with self.cached_session(): # This range of x is all finite, and so is 1 / x. So the # gradient and its approximations should be finite as well. x = constant_op.constant(np.logspace(-4.8, 4.5).astype(np.float16)) diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 1e76ad7476..3ddb5e06c9 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -59,42 +59,48 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testFoldl_Simple(self): - with self.test_session(): - elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") + elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") - r = functional_ops.foldl( - lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), - elems) - self.assertAllEqual(208, self.evaluate(r)) + r = functional_ops.foldl( + lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), + elems) + self.assertAllEqual(208, self.evaluate(r)) - r = functional_ops.foldl( - lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), - elems, - initializer=10) - self.assertAllEqual(880, self.evaluate(r)) + r = functional_ops.foldl( + lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), + elems, + initializer=10) + self.assertAllEqual(880, self.evaluate(r)) @test_util.run_in_graph_and_eager_modes def testFoldl_SingleInputMultiOutput(self): - with self.test_session(): - 
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = np.array([1, -1.0]) - r = functional_ops.foldl(lambda a, x: a + x, elems, initializer) - r_value = self.evaluate(r) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = np.array([1, -1.0]) + r = functional_ops.foldl(lambda a, x: a + x, elems, initializer) + r_value = self.evaluate(r) - self.assertAllEqual(22, r_value[0]) - self.assertAllEqual(20, r_value[1]) + self.assertAllEqual(22, r_value[0]) + self.assertAllEqual(20, r_value[1]) @test_util.run_in_graph_and_eager_modes def testFoldl_MultiInputSingleOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = np.array(1.0) - r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems), - initializer) - self.assertAllEqual(1, self.evaluate(r)) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = np.array(1.0) + r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems), + initializer) + self.assertAllEqual(1, self.evaluate(r)) + + @test_util.run_in_graph_and_eager_modes + def testFoldl_MultiInputDifferentDimsSingleOutput(self): + elems = np.array([[1.0, 1.0, 1.0], [2.0, 3.0, 4.0]]) + other_elems = np.array([-1.0, 1.0]) + initializer = np.array([0.0, 0.0, 0.0]) + r = functional_ops.foldl(lambda a, x: a + x[0] * x[1], + (elems, other_elems), initializer) + self.assertAllEqual([1.0, 2.0, 3.0], self.evaluate(r)) def testFoldl_Scoped(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope("root") as varscope: elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") @@ -114,42 +120,39 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testFoldr_Simple(self): - with self.test_session(): - elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") + elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") - r = functional_ops.foldr( - lambda a, x: 
math_ops.multiply(math_ops.add(a, x), 2), - elems) - self.assertAllEqual(450, self.evaluate(r)) + r = functional_ops.foldr( + lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), + elems) + self.assertAllEqual(450, self.evaluate(r)) - r = functional_ops.foldr( - lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), - elems, - initializer=10) - self.assertAllEqual(1282, self.evaluate(r)) + r = functional_ops.foldr( + lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), + elems, + initializer=10) + self.assertAllEqual(1282, self.evaluate(r)) @test_util.run_in_graph_and_eager_modes def testFoldr_SingleInputMultiOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = np.array([1, -1.0]) - r = functional_ops.foldr(lambda a, x: a + x, elems, initializer) - r_value = self.evaluate(r) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = np.array([1, -1.0]) + r = functional_ops.foldr(lambda a, x: a + x, elems, initializer) + r_value = self.evaluate(r) - self.assertAllEqual(22, r_value[0]) - self.assertAllEqual(20, r_value[1]) + self.assertAllEqual(22, r_value[0]) + self.assertAllEqual(20, r_value[1]) @test_util.run_in_graph_and_eager_modes def testFoldr_MultiInputSingleOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = np.array(1.0) - r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems), - initializer) - self.assertAllEqual(1, self.evaluate(r)) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = np.array(1.0) + r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems), + initializer) + self.assertAllEqual(1, self.evaluate(r)) def testFoldr_Scoped(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope("root") as varscope: elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") @@ -169,7 +172,7 @@ class FunctionalOpsTest(test.TestCase): 
# pylint: disable=unnecessary-lambda def testFold_Grad(self): - with self.test_session(): + with self.cached_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") v = constant_op.constant(2.0, name="v") r = functional_ops.foldl( @@ -185,16 +188,15 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testMap_Simple(self): - with self.test_session(): - nums = [1, 2, 3, 4, 5, 6] - elems = constant_op.constant(nums, name="data") - r = functional_ops.map_fn( - lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems) - self.assertAllEqual( - np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) + nums = [1, 2, 3, 4, 5, 6] + elems = constant_op.constant(nums, name="data") + r = functional_ops.map_fn( + lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems) + self.assertAllEqual( + np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) def testMapSparseTensor(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(TypeError): functional_ops.map_fn( lambda x: x, @@ -211,7 +213,7 @@ class FunctionalOpsTest(test.TestCase): functional_ops.map_fn(lambda x: x, 1) def testMap_Scoped(self): - with self.test_session() as sess: + with self.cached_session() as sess: def double_scoped(x): """2x with a dummy 2 that is scoped.""" @@ -242,7 +244,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(doubles, self.evaluate(r)) def testMap_Grad(self): - with self.test_session(): + with self.cached_session(): param = constant_op.constant(2.0) elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems") y = functional_ops.map_fn( @@ -254,142 +256,131 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testMap_SimpleNotTensor(self): - with self.test_session(): - nums = np.array([1, 2, 3, 4, 5, 6]) - r = functional_ops.map_fn( - lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums) - self.assertAllEqual( - np.array([(x + 3) * 
2 for x in nums]), self.evaluate(r)) + nums = np.array([1, 2, 3, 4, 5, 6]) + r = functional_ops.map_fn( + lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums) + self.assertAllEqual( + np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) @test_util.run_in_graph_and_eager_modes def testMap_SingleInputMultiOutput(self): - with self.test_session(): - nums = np.array([1, 2, 3, 4, 5, 6]) - r = functional_ops.map_fn( - lambda x: ((x + 3) * 2, -(x + 3) * 2), - nums, - dtype=(dtypes.int64, dtypes.int64)) - self.assertEqual(2, len(r)) - self.assertEqual((6,), r[0].get_shape()) - self.assertEqual((6,), r[1].get_shape()) - received = self.evaluate(r) - self.assertAllEqual((nums + 3) * 2, received[0]) - self.assertAllEqual(-(nums + 3) * 2, received[1]) + nums = np.array([1, 2, 3, 4, 5, 6]) + r = functional_ops.map_fn( + lambda x: ((x + 3) * 2, -(x + 3) * 2), + nums, + dtype=(dtypes.int64, dtypes.int64)) + self.assertEqual(2, len(r)) + self.assertEqual((6,), r[0].get_shape()) + self.assertEqual((6,), r[1].get_shape()) + received = self.evaluate(r) + self.assertAllEqual((nums + 3) * 2, received[0]) + self.assertAllEqual(-(nums + 3) * 2, received[1]) @test_util.run_in_graph_and_eager_modes def testMap_MultiOutputMismatchedDtype(self): - with self.test_session(): - nums = np.array([1, 2, 3, 4, 5, 6]) - with self.assertRaisesRegexp( - TypeError, r"two structures don't have the same nested structure"): - # lambda emits tuple, but dtype is a list - functional_ops.map_fn( - lambda x: ((x + 3) * 2, -(x + 3) * 2), - nums, - dtype=[dtypes.int64, dtypes.int64]) + nums = np.array([1, 2, 3, 4, 5, 6]) + with self.assertRaisesRegexp( + TypeError, r"two structures don't have the same nested structure"): + # lambda emits tuple, but dtype is a list + functional_ops.map_fn( + lambda x: ((x + 3) * 2, -(x + 3) * 2), + nums, + dtype=[dtypes.int64, dtypes.int64]) @test_util.run_in_graph_and_eager_modes def testMap_MultiInputSingleOutput(self): - with self.test_session(): - nums = np.array([1, 
2, 3, 4, 5, 6]) - r = functional_ops.map_fn( - lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)), - dtype=dtypes.int64) - self.assertEqual((6,), r.get_shape()) - received = self.evaluate(r) - self.assertAllEqual(nums * nums + (-nums), received) + nums = np.array([1, 2, 3, 4, 5, 6]) + r = functional_ops.map_fn( + lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)), + dtype=dtypes.int64) + self.assertEqual((6,), r.get_shape()) + received = self.evaluate(r) + self.assertAllEqual(nums * nums + (-nums), received) @test_util.run_in_graph_and_eager_modes def testMap_MultiInputSameStructureOutput(self): - with self.test_session(): - nums = np.array([1, 2, 3, 4, 5, 6]) - r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])), - (nums, (2 * nums, -nums))) - r = [r[0], r[1][0], r[1][1]] - self.assertEqual((6,), r[0].get_shape()) - self.assertEqual((6,), r[1].get_shape()) - self.assertEqual((6,), r[2].get_shape()) - received = self.evaluate(r) - self.assertAllEqual(2 * nums, received[0]) - self.assertAllEqual(-nums, received[1]) - self.assertAllEqual(nums, received[2]) + nums = np.array([1, 2, 3, 4, 5, 6]) + r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])), + (nums, (2 * nums, -nums))) + r = [r[0], r[1][0], r[1][1]] + self.assertEqual((6,), r[0].get_shape()) + self.assertEqual((6,), r[1].get_shape()) + self.assertEqual((6,), r[2].get_shape()) + received = self.evaluate(r) + self.assertAllEqual(2 * nums, received[0]) + self.assertAllEqual(-nums, received[1]) + self.assertAllEqual(nums, received[2]) @test_util.run_in_graph_and_eager_modes def testScan_Simple(self): - with self.test_session(): - elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") - v = constant_op.constant(2.0, name="v") + elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") + v = constant_op.constant(2.0, name="v") - # pylint: disable=unnecessary-lambda - r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems) - 
self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r)) + # pylint: disable=unnecessary-lambda + r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems) + self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r)) - r = functional_ops.scan( - lambda a, x: math_ops.multiply(a, x), elems, initializer=v) - self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) - # pylint: enable=unnecessary-lambda + r = functional_ops.scan( + lambda a, x: math_ops.multiply(a, x), elems, initializer=v) + self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) + # pylint: enable=unnecessary-lambda @test_util.run_in_graph_and_eager_modes def testScan_Reverse(self): - with self.test_session(): - elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") - v = constant_op.constant(2.0, name="v") - - # pylint: disable=unnecessary-lambda - r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems, - reverse=True) - self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r)) - r = functional_ops.scan( - lambda a, x: math_ops.multiply(a, x), elems, initializer=v, - reverse=True) - self.assertAllEqual([1440., 1440., 720., 240., 60., 12.], - self.evaluate(r)) - # pylint: enable=unnecessary-lambda + elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") + v = constant_op.constant(2.0, name="v") + + # pylint: disable=unnecessary-lambda + r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems, + reverse=True) + self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r)) + r = functional_ops.scan( + lambda a, x: math_ops.multiply(a, x), elems, initializer=v, + reverse=True) + self.assertAllEqual([1440., 1440., 720., 240., 60., 12.], + self.evaluate(r)) + # pylint: enable=unnecessary-lambda @test_util.run_in_graph_and_eager_modes def testScan_SingleInputMultiOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - 
initializer = (np.array(1.0), np.array(-1.0)) - r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems, - initializer) - r_value = self.evaluate(r) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = (np.array(1.0), np.array(-1.0)) + r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems, + initializer) + r_value = self.evaluate(r) - self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0]) - self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1]) + self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0]) + self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1]) @test_util.run_in_graph_and_eager_modes def testScan_MultiInputSingleOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = np.array(1.0) - # Multiply a * 1 each time - r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]), - (elems + 1, -elems), initializer) - self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r)) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = np.array(1.0) + # Multiply a * 1 each time + r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]), + (elems + 1, -elems), initializer) + self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r)) @test_util.run_in_graph_and_eager_modes def testScan_MultiInputSameTypeOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]), - (elems, -elems)) - r_value = self.evaluate(r) - self.assertAllEqual(np.cumsum(elems), r_value[0]) - self.assertAllEqual(np.cumsum(-elems), r_value[1]) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]), + (elems, -elems)) + r_value = self.evaluate(r) + self.assertAllEqual(np.cumsum(elems), r_value[0]) + self.assertAllEqual(np.cumsum(-elems), r_value[1]) 
@test_util.run_in_graph_and_eager_modes def testScan_MultiOutputMismatchedInitializer(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = np.array(1.0) - # Multiply a * 1 each time - with self.assertRaisesRegexp( - ValueError, "two structures don't have the same nested structure"): - functional_ops.scan(lambda a, x: (a, -a), elems, initializer) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = np.array(1.0) + # Multiply a * 1 each time + with self.assertRaisesRegexp( + ValueError, "two structures don't have the same nested structure"): + functional_ops.scan(lambda a, x: (a, -a), elems, initializer) def testScan_Scoped(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope("root") as varscope: elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") @@ -411,30 +402,29 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testScanFoldl_Nested(self): - with self.test_session(): - elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data") - inner_elems = constant_op.constant([0.5, 0.5], name="data") - - def r_inner(a, x): - return functional_ops.foldl( - lambda b, y: b * y * x, inner_elems, initializer=a) - - r = functional_ops.scan(r_inner, elems) - - # t == 0 (returns 1) - # t == 1, a == 1, x == 2 (returns 1) - # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1 - # t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1 - # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25) - # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5 - # t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5 - # t == 3, a == 2.25, x == 4 (returns 9) - # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5 - # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9 - self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r)) + elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data") + inner_elems = 
constant_op.constant([0.5, 0.5], name="data") + + def r_inner(a, x): + return functional_ops.foldl( + lambda b, y: b * y * x, inner_elems, initializer=a) + + r = functional_ops.scan(r_inner, elems) + + # t == 0 (returns 1) + # t == 1, a == 1, x == 2 (returns 1) + # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1 + # t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1 + # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25) + # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5 + # t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5 + # t == 3, a == 2.25, x == 4 (returns 9) + # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5 + # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9 + self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r)) def testScan_Control(self): - with self.test_session() as sess: + with self.cached_session() as sess: s = array_ops.placeholder(dtypes.float32, shape=[None]) b = array_ops.placeholder(dtypes.bool) @@ -445,7 +435,7 @@ class FunctionalOpsTest(test.TestCase): b: True})) def testScan_Grad(self): - with self.test_session(): + with self.cached_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") v = constant_op.constant(2.0, name="v") @@ -470,22 +460,20 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testFoldShape(self): - with self.test_session(): - x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) + x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) - def fn(_, current_input): - return current_input + def fn(_, current_input): + return current_input - initializer = constant_op.constant([0, 0, 0]) - y = functional_ops.foldl(fn, x, initializer=initializer) - self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) + initializer = constant_op.constant([0, 0, 0]) + y = functional_ops.foldl(fn, x, initializer=initializer) + self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) @test_util.run_in_graph_and_eager_modes def testMapShape(self): - 
with self.test_session(): - x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) - y = functional_ops.map_fn(lambda e: e, x) - self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) + x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) + y = functional_ops.map_fn(lambda e: e, x) + self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) def testMapUnknownShape(self): x = array_ops.placeholder(dtypes.float32) @@ -494,15 +482,14 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testMapEmptyScalar(self): - with self.test_session(): - map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([])) - self.assertAllEqual([0], map_return.get_shape().dims) - self.assertAllEqual([0], self.evaluate(map_return).shape) + map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([])) + self.assertAllEqual([0], map_return.get_shape().dims) + self.assertAllEqual([0], self.evaluate(map_return).shape) # TODO(akshayka): this test fails in eager: the iterable is of length 0 so # so the body of the while loop never executes def testMapEmptyTensor(self): - with self.test_session(): + with self.cached_session(): map_return = functional_ops.map_fn(lambda x: array_ops.zeros([3, 2]), constant_op.constant([])) self.assertAllEqual([0, 3, 2], map_return.get_shape().dims) @@ -510,20 +497,19 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testScanShape(self): - with self.test_session(): - x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) + x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) - def fn(_, current_input): - return current_input + def fn(_, current_input): + return current_input - initializer = constant_op.constant([0, 0, 0]) - y = functional_ops.scan(fn, x, initializer=initializer) - self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) + initializer = constant_op.constant([0, 0, 0]) + y = functional_ops.scan(fn, x, initializer=initializer) + self.assertAllEqual(y.get_shape(), 
self.evaluate(y).shape) # TODO(akshayka): this test fails in eager: the iterable is of length 0 so # so the body of the while loop never executes def testScanEmptyTensor(self): - with self.test_session(): + with self.cached_session(): x = functional_ops.scan( lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4])) self.assertAllEqual([0, 2, 4], x.get_shape()) @@ -540,7 +526,7 @@ class FunctionalOpsTest(test.TestCase): self.assertIs(None, y.get_shape().dims) def testScanVaryingShape(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 2]) x_t = array_ops.transpose(x) # scan over dimension 0 (with shape None) @@ -619,7 +605,7 @@ class FunctionalOpsTest(test.TestCase): remote_op = functional_ops.remote_call( args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0") - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) mul = sess.run(remote_op) self.assertEqual(mul, [6]) @@ -643,7 +629,7 @@ class FunctionalOpsTest(test.TestCase): f=_remote_fn, target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0 - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) mul = sess.run(remote_op) self.assertEqual(mul, 9.0) @@ -667,7 +653,7 @@ class FunctionalOpsTest(test.TestCase): f=_remote_fn, target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0 - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) mul = sess.run(remote_op) self.assertEqual(mul, 9.0) @@ -686,7 +672,7 @@ class FunctionalOpsTest(test.TestCase): remote_op = functional_ops.remote_call( args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0") - with self.test_session() as sess: + with self.cached_session() as sess: ret = sess.run(remote_op) self.assertAllEqual(ret, [b"a"]) diff --git 
a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index 9b6aee64aa..0f5607712b 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -170,9 +170,8 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_pop_back( l_cpu, element_dtype=dtypes.float32)[1]), 2.0) - @test_util.run_in_graph_and_eager_modes def testGraphStack(self): - with context.graph_mode(), self.test_session(): + with self.cached_session(): tl = list_ops.empty_tensor_list( element_shape=constant_op.constant([1], dtype=dtypes.int32), element_dtype=dtypes.int32) @@ -182,9 +181,8 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)), [[1]]) - @test_util.run_in_graph_and_eager_modes def testGraphStackInLoop(self): - with context.graph_mode(), self.test_session(): + with self.cached_session(): t1 = list_ops.empty_tensor_list( element_shape=constant_op.constant([], dtype=dtypes.int32), element_dtype=dtypes.int32) @@ -200,9 +198,8 @@ class ListOpsTest(test_util.TensorFlowTestCase): s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32) self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3]) - @test_util.run_in_graph_and_eager_modes def testGraphStackSwitchDtype(self): - with context.graph_mode(), self.test_session(): + with self.cached_session(): list_ = list_ops.empty_tensor_list( element_shape=constant_op.constant([], dtype=dtypes.int32), element_dtype=dtypes.int32) @@ -222,9 +219,8 @@ class ListOpsTest(test_util.TensorFlowTestCase): np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) self.assertAllEqual(self.evaluate(s1), np_s1) - @test_util.run_in_graph_and_eager_modes def testGraphStackInLoopSwitchDtype(self): - with context.graph_mode(), self.test_session(): + with self.cached_session(): t1 = list_ops.empty_tensor_list( element_shape=constant_op.constant([], dtype=dtypes.int32), 
element_dtype=dtypes.int32) @@ -476,6 +472,47 @@ class ListOpsTest(test_util.TensorFlowTestCase): self.evaluate(t_full_zeros), np.zeros( (2,), dtype=dtype.as_numpy_dtype)) + @test_util.run_in_graph_and_eager_modes + def testZerosLikeVariant(self): + for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16, + dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32, + dtypes.float64, dtypes.complex64, dtypes.complex128, + dtypes.bool): + l = list_ops.empty_tensor_list( + element_dtype=dtypes.variant, element_shape=scalar_shape()) + + sub_l = list_ops.empty_tensor_list( + element_dtype=dtype, element_shape=scalar_shape()) + l = list_ops.tensor_list_push_back(l, sub_l) + sub_l = list_ops.tensor_list_push_back(sub_l, math_ops.cast( + 1, dtype=dtype)) + l = list_ops.tensor_list_push_back(l, sub_l) + sub_l = list_ops.tensor_list_push_back(sub_l, math_ops.cast( + 2, dtype=dtype)) + l = list_ops.tensor_list_push_back(l, sub_l) + + # l : [[], + # [1], + # [1, 2]] + # + # l_zeros : [[], + # [0], + # [0, 0]] + l_zeros = array_ops.zeros_like(l) + + outputs = [] + for _ in range(3): + l_zeros, out = list_ops.tensor_list_pop_back( + l_zeros, element_dtype=dtypes.variant) + outputs.append(list_ops.tensor_list_stack(out, element_dtype=dtype)) + + # Note: `outputs` contains popped values so the order is reversed. 
+ self.assertAllEqual(self.evaluate(outputs[2]), []) + self.assertAllEqual( + self.evaluate(outputs[1]), np.zeros((1,), dtype=dtype.as_numpy_dtype)) + self.assertAllEqual( + self.evaluate(outputs[0]), np.zeros((2,), dtype=dtype.as_numpy_dtype)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 50154a45a8..79fcbaad43 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -61,7 +61,7 @@ class PyFuncTest(test.TestCase): for dtype in [dtypes.float16, dtypes.float32, dtypes.float64, dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16, dtypes.int32, dtypes.int64]: - with self.test_session(): + with self.cached_session(): x = constant_op.constant(1, dtype=dtype) y = constant_op.constant(2, dtype=dtype) z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype)) @@ -71,7 +71,7 @@ class PyFuncTest(test.TestCase): def sub_func(x, y): return x - y for dtype in [dtypes.complex64, dtypes.complex128]: - with self.test_session(): + with self.cached_session(): x = constant_op.constant(1 + 1j, dtype=dtype) y = constant_op.constant(2 - 2j, dtype=dtype) z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype)) @@ -81,21 +81,21 @@ class PyFuncTest(test.TestCase): def and_func(x, y): return x and y dtype = dtypes.bool - with self.test_session(): + with self.cached_session(): x = constant_op.constant(True, dtype=dtype) y = constant_op.constant(False, dtype=dtype) z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype)) self.assertEqual(z, False) def testSingleType(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant(1.0, dtypes.float32) y = constant_op.constant(2.0, dtypes.float32) z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32)) self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32)) def testScalar(self): - with self.test_session(): + with 
self.cached_session(): x = constant_op.constant(1.0, dtypes.float32) y = constant_op.constant(2.0, dtypes.float32) z = self.evaluate( @@ -103,7 +103,7 @@ class PyFuncTest(test.TestCase): self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32)) def testArray(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant([1.0, 2.0], dtypes.float64) y = constant_op.constant([2.0, 3.0], dtypes.float64) z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64])) @@ -111,14 +111,14 @@ class PyFuncTest(test.TestCase): np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64)) def testComplexType(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant(1 + 2j, dtypes.complex64) y = constant_op.constant(3 + 4j, dtypes.complex64) z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64)) self.assertAllClose(z, np_func(1 + 2j, 3 + 4j)) def testRFFT(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant([1., 2., 3., 4.], dtypes.float32) def rfft(x): @@ -128,7 +128,7 @@ class PyFuncTest(test.TestCase): self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.])) def testPythonLiteral(self): - with self.test_session(): + with self.cached_session(): def literal(x): return 1.0 if float(x) == 0.0 else 0.0 @@ -138,7 +138,7 @@ class PyFuncTest(test.TestCase): self.assertAllClose(y, 1.0) def testList(self): - with self.test_session(): + with self.cached_session(): def list_func(x): return [x, x + 1] @@ -150,7 +150,7 @@ class PyFuncTest(test.TestCase): def testTuple(self): # returns a tuple - with self.test_session(): + with self.cached_session(): def tuple_func(x): return x, x + 1 @@ -161,7 +161,7 @@ class PyFuncTest(test.TestCase): self.assertAllClose(y, [0.0, 1.0]) # returns a tuple, Tout and inp a tuple - with self.test_session(): + with self.cached_session(): x = constant_op.constant(0.0, dtypes.float64) y = self.evaluate( script_ops.py_func(tuple_func, (x,), @@ -176,7 
+176,7 @@ class PyFuncTest(test.TestCase): def read_and_return_strings(x, y): return x + y - with self.test_session(): + with self.cached_session(): x = constant_op.constant([b"hello", b"hi"], dtypes.string) y = self.evaluate( script_ops.py_func(read_fixed_length_numpy_strings, [], @@ -193,7 +193,7 @@ class PyFuncTest(test.TestCase): def read_and_return_strings(x, y): return x + y - with self.test_session(): + with self.cached_session(): x = constant_op.constant(["hello", "hi"], dtypes.string) y = self.evaluate( script_ops.py_func(read_fixed_length_numpy_strings, [], @@ -210,7 +210,7 @@ class PyFuncTest(test.TestCase): def read_and_return_strings(x, y): return x + y - with self.test_session(): + with self.cached_session(): x = constant_op.constant(["hello", "hi"], dtypes.string) y, = script_ops.py_func(read_object_array, [], [dtypes.string]) @@ -219,19 +219,19 @@ class PyFuncTest(test.TestCase): def testStringPadding(self): correct = [b"this", b"is", b"a", b"test"] - with self.test_session(): + with self.cached_session(): s, = script_ops.py_func(lambda: [correct], [], [dtypes.string]) self.assertAllEqual(s.eval(), correct) def testStringPaddingAreConvertedToBytes(self): inp = ["this", "is", "a", "test"] correct = [b"this", b"is", b"a", b"test"] - with self.test_session(): + with self.cached_session(): s, = script_ops.py_func(lambda: [inp], [], [dtypes.string]) self.assertAllEqual(s.eval(), correct) def testLarge(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.zeros([1000000], dtype=np.float32) y = script_ops.py_func(lambda x: x + 1, [x], [dtypes.float32]) z = script_ops.py_func(lambda x: x * 2, [x], [dtypes.float32]) @@ -239,12 +239,12 @@ class PyFuncTest(test.TestCase): sess.run([y[0].op, z[0].op]) def testNoInput(self): - with self.test_session(): + with self.cached_session(): x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64)) self.assertAllClose(x, 42.0) def testAlias(self): - with 
self.test_session(): + with self.cached_session(): np_array = np.array([1.0, 2.0], dtype=np.float32) tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32]) value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32) @@ -252,7 +252,7 @@ class PyFuncTest(test.TestCase): self.assertAllEqual(np_array, [1.0, 2.0]) def testReturnUnicodeString(self): - with self.test_session(): + with self.cached_session(): correct = u"你好 世界" def unicode_string(): @@ -262,7 +262,7 @@ class PyFuncTest(test.TestCase): self.assertEqual(z.eval(), correct.encode("utf8")) def testBadNumpyReturnType(self): - with self.test_session(): + with self.cached_session(): def bad(): # Structured numpy arrays aren't supported. @@ -275,7 +275,7 @@ class PyFuncTest(test.TestCase): y.eval() def testBadReturnType(self): - with self.test_session(): + with self.cached_session(): def bad(): # Non-string python objects aren't supported. @@ -288,7 +288,7 @@ class PyFuncTest(test.TestCase): z.eval() def testReturnInput(self): - with self.test_session(): + with self.cached_session(): def ident(x): return x[0] @@ -303,7 +303,7 @@ class PyFuncTest(test.TestCase): self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]})) def testStateful(self): - # Not using self.test_session(), which disables optimization. + # Not using self.cached_session(), which disables optimization. with session_lib.Session() as sess: producer = iter(range(3)) x, = script_ops.py_func(lambda: next(producer), [], [dtypes.int64]) @@ -312,7 +312,7 @@ class PyFuncTest(test.TestCase): self.assertEqual(sess.run(x), 2) def testStateless(self): - # Not using self.test_session(), which disables optimization. + # Not using self.cached_session(), which disables optimization. 
with session_lib.Session() as sess: producer = iter(range(3)) x, = script_ops.py_func( @@ -331,7 +331,7 @@ class PyFuncTest(test.TestCase): self.assertEqual(None, ops.get_gradient_function(y.op)) def testCOrder(self): - with self.test_session(): + with self.cached_session(): val = [[1, 2], [3, 4]] x, = script_ops.py_func(lambda: np.array(val, order="F"), [], [dtypes.int64]) @@ -339,7 +339,7 @@ class PyFuncTest(test.TestCase): def testParallel(self): # Tests that tf.py_func's can run in parallel if they release the GIL. - with self.test_session() as session: + with self.cached_session() as session: q = queue.Queue(1) def blocking_put(): @@ -375,7 +375,7 @@ class PyFuncTest(test.TestCase): def value(self): return self._value - with self.test_session(): + with self.cached_session(): s = State() op = s.increment(constant_op.constant(2, dtypes.int64)) ret = self.evaluate(op) @@ -389,7 +389,7 @@ class PyFuncTest(test.TestCase): f = script_ops.py_func( do_nothing, [constant_op.constant(3, dtypes.int64)], [], stateful=False) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(sess.run(f), []) def _testExceptionHandling(self, py_exp, tf_exp, eager=False): @@ -417,21 +417,22 @@ class PyFuncTest(test.TestCase): else: f = script_ops.py_func(raise_exception, [], []) - with self.test_session(): - with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check): - self.evaluate(f) + with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check): + self.evaluate(f) def testExceptionHandling(self): - self._testExceptionHandling(ValueError, errors.InvalidArgumentError) - self._testExceptionHandling(TypeError, errors.InvalidArgumentError) - self._testExceptionHandling(StopIteration, errors.OutOfRangeError) - self._testExceptionHandling(MemoryError, errors.ResourceExhaustedError) - self._testExceptionHandling(NotImplementedError, errors.UnimplementedError) + with self.cached_session(): + self._testExceptionHandling(ValueError, 
errors.InvalidArgumentError) + self._testExceptionHandling(TypeError, errors.InvalidArgumentError) + self._testExceptionHandling(StopIteration, errors.OutOfRangeError) + self._testExceptionHandling(MemoryError, errors.ResourceExhaustedError) + self._testExceptionHandling(NotImplementedError, + errors.UnimplementedError) - class WeirdError(Exception): - pass + class WeirdError(Exception): + pass - self._testExceptionHandling(WeirdError, errors.UnknownError) + self._testExceptionHandling(WeirdError, errors.UnknownError) # ----- Tests shared by py_func and eager_py_func ----- def testCleanup(self): @@ -452,7 +453,7 @@ class PyFuncTest(test.TestCase): # (see #18292) _ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) _ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) - + # Call garbage collector to enforce deletion. make_graphs() ops.reset_default_graph() @@ -610,7 +611,7 @@ class PyFuncTest(test.TestCase): func=log_huber, inp=[x, m], Tout=dtypes.float32) dy_dx = gradients_impl.gradients(y, x)[0] - with self.test_session() as sess: + with self.cached_session() as sess: # Takes the first branch of log_huber. 
y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0}) self.assertEqual(y, 1.0) diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index d0ed08933d..f90545f84c 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -54,7 +54,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(0, len(gc.garbage)) def testHandleDtypeShapeMatch(self): - with self.test_session(): + with self.cached_session(): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) with self.assertRaises(ValueError): resource_variable_ops.assign_variable_op( @@ -123,7 +123,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertFalse(np.allclose(variable.numpy(), copied_variable.numpy())) def testGraphDeepCopy(self): - with self.test_session(): + with self.cached_session(): init_value = np.ones((4, 4, 4)) variable = resource_variable_ops.ResourceVariable(init_value, name="init") @@ -145,13 +145,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): # variable graph. def testFetchHandle(self): - with self.test_session(): + with self.cached_session(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1], name="foo") self.assertGreater(len(handle.eval()), 0) def testCachedValueReadBeforeWrite(self): - with self.test_session() as sess: + with self.cached_session() as sess: v = resource_variable_ops.ResourceVariable(0.0, caching_device="cpu:0") sess.run(v.initializer) value, _ = sess.run([v, v.assign_add(1.0)]) @@ -492,7 +492,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): # TODO(alive): how should this work in Eager mode? 
def testInitFn(self): - with self.test_session(): + with self.cached_session(): v = resource_variable_ops.ResourceVariable( initial_value=lambda: 1, dtype=dtypes.float32) self.assertEqual(v.handle.op.colocation_groups(), @@ -569,11 +569,11 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(2.0, self.evaluate(v.value())) def testVariableDefInitializedInstances(self): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: v_def = resource_variable_ops.ResourceVariable( initial_value=constant_op.constant(3.0)).to_proto() - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: # v describes a VariableDef-based variable without an initial value. v = resource_variable_ops.ResourceVariable(variable_def=v_def) self.assertEqual(3.0, sess.run(v.initialized_value())) @@ -584,7 +584,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(1.0, v.initialized_value().eval()) v_def.ClearField("initial_value_name") - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: # Restoring a legacy VariableDef proto that does not have # initial_value_name set should still work. 
v = resource_variable_ops.ResourceVariable(variable_def=v_def) @@ -615,17 +615,16 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes def testSparseRead(self): - with self.test_session(): - init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4)) - v = resource_variable_ops.ResourceVariable( - constant_op.constant(init_value, dtype=dtypes.int32), name="var3") - self.evaluate(variables.global_variables_initializer()) + init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4)) + v = resource_variable_ops.ResourceVariable( + constant_op.constant(init_value, dtype=dtypes.int32), name="var3") + self.evaluate(variables.global_variables_initializer()) - value = self.evaluate(v.sparse_read([0, 3, 1, 2])) - self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value) + value = self.evaluate(v.sparse_read([0, 3, 1, 2])) + self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value) def testToFromProto(self): - with self.test_session(): + with self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() @@ -686,7 +685,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): handle, ignore_lookup_error=True)) def testAssignDifferentShapes(self): - with self.test_session() as sess, variable_scope.variable_scope( + with self.cached_session() as sess, variable_scope.variable_scope( "foo", use_resource=True): var = variable_scope.get_variable("x", shape=[1, 1], dtype=dtypes.float32) placeholder = array_ops.placeholder(dtypes.float32) @@ -728,7 +727,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): _ = w.value().op.get_attr("_class") def testSharedName(self): - with self.test_session(): + with self.cached_session(): v = resource_variable_ops.ResourceVariable(300.0, name="var4") variables.global_variables_initializer().run() @@ -746,7 +745,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.read_variable_op(x, 
v.dtype.base_dtype).eval() def testSharedNameWithNamescope(self): - with self.test_session(): + with self.cached_session(): with ops.name_scope("foo"): v = resource_variable_ops.ResourceVariable(300.0, name="var6") self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access @@ -774,7 +773,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape)) def testSetInitialValue(self): - with self.test_session(): + with self.cached_session(): # Initialize variable with a value different from the initial value passed # in the constructor. v = resource_variable_ops.ResourceVariable(2.0) diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index 562d11f0b0..a28cdc3b26 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -197,7 +197,7 @@ class RNNTest(test.TestCase): else: inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) - with self.test_session() as sess: + with self.cached_session(use_gpu=True) as sess: outputs, state = rnn.dynamic_rnn( cell, inputs, dtype=dtypes.float32, sequence_length=[4]) if not in_eager_mode: @@ -217,7 +217,7 @@ class RNNTest(test.TestCase): else: inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) - with self.test_session() as sess: + with self.cached_session(use_gpu=True) as sess: outputs, state = rnn.dynamic_rnn( cell, inputs, dtype=dtypes.float32, sequence_length=[4]) if not in_eager_mode: @@ -246,7 +246,7 @@ class RNNTest(test.TestCase): else: inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) - with self.test_session() as sess: + with self.cached_session(use_gpu=True) as sess: outputs, state = rnn.dynamic_rnn( cell, inputs, dtype=dtypes.float32, sequence_length=[4]) state = (state[0], state[1].stack()) @@ -321,7 +321,7 @@ class RNNTest(test.TestCase): self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f64, 5, 7, 3) def 
testRNNWithKerasSimpleRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: input_shape = 10 output_shape = 5 timestep = 4 @@ -354,7 +354,7 @@ class RNNTest(test.TestCase): self.assertEqual(len(state), batch) def testRNNWithKerasGRUCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: input_shape = 10 output_shape = 5 timestep = 4 @@ -387,7 +387,7 @@ class RNNTest(test.TestCase): self.assertEqual(len(state), batch) def testRNNWithKerasLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: input_shape = 10 output_shape = 5 timestep = 4 @@ -424,7 +424,7 @@ class RNNTest(test.TestCase): self.assertEqual(len(state[1]), batch) def testRNNWithStackKerasCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: input_shape = 10 output_shape = 5 timestep = 4 @@ -465,7 +465,7 @@ class RNNTest(test.TestCase): self.assertEqual(len(s), batch) def testStaticRNNWithKerasSimpleRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: input_shape = 10 output_shape = 5 timestep = 4 @@ -567,7 +567,7 @@ class RNNTest(test.TestCase): rnn_cell_impl.GRUCell( 32, kernel_initializer="ones", dtype=dtypes.float32) ]: - with self.test_session(): + with self.cached_session(): x = keras.Input((None, 5)) layer = keras.layers.RNN(cell) y = layer(x) diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py index c4e9c982b5..c6a6b2a7fa 100644 --- a/tensorflow/python/ops/cond_v2_impl.py +++ b/tensorflow/python/ops/cond_v2_impl.py @@ -180,16 +180,16 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name def _get_func_graphs(if_op): - """Returns `_FuncGraph`s for the input op branches. + """Returns `FuncGraph`s for the input op branches. Args: if_op: The _If Operation. Returns: - A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch. 
+ A 2-tuple of the `FuncGraph`s of the then_branch and else_branch. """ def _get_func_graph_for_branch(branch_name): - """Generates and returns a _FuncGraph for the given branch.""" + """Generates and returns a FuncGraph for the given branch.""" inputs = if_op.inputs[1:] # First input is pred. input_shapes = [t.shape for t in inputs] func_name = if_op.get_attr(branch_name).name @@ -197,7 +197,7 @@ def _get_func_graphs(if_op): # `if_op.graph` may not be the same as `ops.get_default_graph()` e.g. # in the case of nested if ops or when the gradient is being computed # from inside a Defun. We build the `func_graph` with `if_op.graph` as its - # `outer_graph`. This resembles how the `_FuncGraph` was built in the + # `outer_graph`. This resembles how the `FuncGraph` was built in the # forward pass. We need this so that we can resolve references to tensors # in `func_graph` from its gradient graph in `_resolve_grad_inputs`. with if_op.graph.as_default(): @@ -221,7 +221,7 @@ def _grad_fn(func_graph, grads): func_graph's outputs w.r.t. its inputs. Args: - func_graph: function._FuncGraph. The corresponding forward-pass function. + func_graph: function.FuncGraph. The corresponding forward-pass function. grads: The list of input gradient Tensors. Returns: @@ -259,7 +259,7 @@ def _grad_fn(func_graph, grads): def _create_grad_func(func_graph, grads, name): - """Returns the _FuncGraph representation of _grad_fn.""" + """Returns the FuncGraph representation of _grad_fn.""" return _function.func_graph_from_py_func( name, lambda: _grad_fn(func_graph, grads), [], {}) @@ -277,8 +277,8 @@ def _resolve_grad_inputs(cond_graph, grad_graph): functions, this is always possible. Args: - cond_graph: function._FuncGraph. The forward-pass function. - grad_graph: function._FuncGraph. The gradients function. + cond_graph: function.FuncGraph. The forward-pass function. + grad_graph: function.FuncGraph. The gradients function. Returns: A list of inputs tensors to be passed to grad_graph. 
@@ -313,7 +313,7 @@ def _create_new_tf_function(func_graph): """Converts func_graph to a TF_Function and adds it to the current graph. Args: - func_graph: function._FuncGraph + func_graph: function.FuncGraph Returns: The name of the new TF_Function. @@ -365,8 +365,8 @@ def _pad_params(true_graph, false_graph, true_params, false_params): There is no merging of params. Args: - true_graph: function._FuncGraph - false_graph: function._FuncGraph + true_graph: function.FuncGraph + false_graph: function.FuncGraph true_params: a list of Tensors from true_graph false_params: a list of Tensors from false_graph @@ -391,8 +391,8 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): graph to avoid duplicating shared arguments. Args: - true_graph: function._FuncGraph - false_graph: function._FuncGraph + true_graph: function.FuncGraph + false_graph: function.FuncGraph true_inputs: a list of Tensors in the outer graph. The inputs for true_graph. false_inputs: a list of Tensors in the outer graph. The inputs for @@ -421,7 +421,7 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): _create_dummy_params(false_graph, true_only_inputs) + [false_input_to_param[t] for t in false_only_inputs]) - # Rewrite the _FuncGraphs' state to reflect the new inputs. + # Rewrite the FuncGraphs' state to reflect the new inputs. true_graph.captures = collections.OrderedDict(zip(new_inputs, true_graph.inputs)) false_graph.captures = collections.OrderedDict(zip(new_inputs, @@ -434,7 +434,7 @@ def _create_dummy_params(func_graph, template_tensors): """Creates tensors in func_graph to represent template_tensors. Args: - func_graph: function._FuncGraph. + func_graph: function.FuncGraph. template_tensors: a list of tensors in the outer graph. Returns: @@ -451,27 +451,16 @@ def _get_grad_fn_name(func_graph): Ensures this name is unique in the entire hierarchy. Args: - func_graph: The _FuncGraph. + func_graph: The FuncGraph. 
Returns: A string, the name to use for the gradient function. """ name = "%s_grad" % func_graph.name - - base_name = name - counter = 1 - has_conflict = True - while has_conflict: - curr_graph = func_graph.outer_graph - has_conflict = curr_graph._is_function(name) - while not has_conflict and isinstance(curr_graph, _function.FuncGraph): - curr_graph = curr_graph.outer_graph - has_conflict = curr_graph._is_function(name) - if has_conflict: - name = "%s_%s" % (base_name, counter) - counter += 1 - - return name + outer_most_graph = func_graph + while isinstance(outer_most_graph, _function.FuncGraph): + outer_most_graph = outer_most_graph.outer_graph + return outer_most_graph.unique_name(name) def _check_same_outputs(true_graph, false_graph): diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py index ddf9442cd2..578e7b7dd2 100644 --- a/tensorflow/python/ops/distributions/distribution.py +++ b/tensorflow/python/ops/distributions/distribution.py @@ -446,6 +446,24 @@ class Distribution(_BaseDistribution): self._graph_parents = graph_parents self._name = name + @property + def _parameters(self): + return self._parameter_dict + + @_parameters.setter + def _parameters(self, value): + """Intercept assignments to self._parameters to avoid reference cycles. + + Parameters are often created using locals(), so we need to clean out any + references to `self` before assigning it to an attribute. + + Args: + value: A dictionary of parameters to assign to the `_parameters` property. + """ + if "self" in value: + del value["self"] + self._parameter_dict = value + @classmethod def param_shapes(cls, sample_shape, name="DistributionParamShapes"): """Shapes of parameters given the desired shape of a call to `sample()`. 
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index 7f851e3646..f25ed700d6 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -41,6 +41,7 @@ limitations under the License. #define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ #include <complex> +#include <vector> #include "tensorflow/stream_executor/host_or_device_scalar.h" #include "tensorflow/stream_executor/lib/array_slice.h" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt index eb41deee13..9f6dcd8fdb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt @@ -9,16 +9,14 @@ tf_proto { type: TYPE_STRING } field { - name: "client_handles_error_formatting" - number: 2 - label: LABEL_OPTIONAL - type: TYPE_BOOL - } - field { name: "executor_type" number: 3 label: LABEL_OPTIONAL type: TYPE_STRING } + reserved_range { + start: 2 + end: 3 + } } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt index e565b903d2..f3a515163d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt @@ -132,17 +132,15 @@ tf_proto { type: TYPE_STRING } field { - name: "client_handles_error_formatting" - number: 2 - label: LABEL_OPTIONAL - type: TYPE_BOOL - } - field { name: "executor_type" number: 3 label: LABEL_OPTIONAL type: TYPE_STRING } + reserved_range { + start: 2 + end: 3 + } } } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt index eb41deee13..9f6dcd8fdb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt +++ 
b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt @@ -9,16 +9,14 @@ tf_proto { type: TYPE_STRING } field { - name: "client_handles_error_formatting" - number: 2 - label: LABEL_OPTIONAL - type: TYPE_BOOL - } - field { name: "executor_type" number: 3 label: LABEL_OPTIONAL type: TYPE_STRING } + reserved_range { + start: 2 + end: 3 + } } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt index e565b903d2..f3a515163d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt @@ -132,17 +132,15 @@ tf_proto { type: TYPE_STRING } field { - name: "client_handles_error_formatting" - number: 2 - label: LABEL_OPTIONAL - type: TYPE_BOOL - } - field { name: "executor_type" number: 3 label: LABEL_OPTIONAL type: TYPE_STRING } + reserved_range { + start: 2 + end: 3 + } } } } diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu b/tensorflow/tools/ci_build/Dockerfile.gpu index f05c7a4809..a4cad4b6c6 100644 --- a/tensorflow/tools/ci_build/Dockerfile.gpu +++ b/tensorflow/tools/ci_build/Dockerfile.gpu @@ -30,3 +30,4 @@ RUN mkdir /usr/local/cuda-9.0/lib && \ # Configure the build for our CUDA configuration. 
ENV TF_NEED_CUDA 1 +ENV TF_NEED_TENSORRT 1 diff --git a/tensorflow/tools/ci_build/install/install_deb_packages.sh b/tensorflow/tools/ci_build/install/install_deb_packages.sh index 9640810533..179fc42d60 100755 --- a/tensorflow/tools/ci_build/install/install_deb_packages.sh +++ b/tensorflow/tools/ci_build/install/install_deb_packages.sh @@ -67,6 +67,12 @@ apt-get install -y --no-install-recommends \ zip \ zlib1g-dev +apt-get update && \ + apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \ + apt-get update && \ + apt-get install libnvinfer4=4.1.2-1+cuda9.0 && \ + apt-get install libnvinfer-dev=4.1.2-1+cuda9.0 + # populate the database updatedb diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh index f958b3c9b7..60c974c36b 100755 --- a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh +++ b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh @@ -52,6 +52,7 @@ ${DOCKER_BINARY} run \ -e "PYTHON_BIN_PATH=/usr/bin/python" \ -e "TF_NEED_HDFS=0" \ -e "TF_NEED_CUDA=${TF_NEED_CUDA}" \ + -e "TF_NEED_TENSORRT=${TF_NEED_CUDA}" \ -e "TF_NEED_OPENCL_SYCL=0" \ "${DOCKER_IMAGE}" \ "/workspace/tensorflow/tools/ci_build/linux/libtensorflow.sh" diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 997afc6ac7..549056c6c4 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -947,6 +947,7 @@ class _ClassPageInfo(object): self._aliases = None self._doc = None self._guides = None + self._namedtuplefields = None self._bases = None self._properties = [] @@ -1030,6 +1031,17 @@ class _ClassPageInfo(object): self._guides = guides @property + def namedtuplefields(self): + return self._namedtuplefields + + def set_namedtuplefields(self, py_class): + if issubclass(py_class, tuple): + if all( + hasattr(py_class, attr) + for attr in ('_asdict', '_fields', '_make', '_replace')): + self._namedtuplefields = py_class._fields + + 
@property def bases(self): """Returns a list of `_LinkInfo` objects pointing to the class' parents.""" return self._bases @@ -1066,7 +1078,15 @@ class _ClassPageInfo(object): @property def properties(self): """Returns a list of `_PropertyInfo` describing the class' properties.""" - return self._properties + props_dict = {prop.short_name: prop for prop in self._properties} + props = [] + if self.namedtuplefields: + for field in self.namedtuplefields: + props.append(props_dict.pop(field)) + + props.extend(sorted(props_dict.values())) + + return props def _add_property(self, short_name, full_name, obj, doc): """Adds a `_PropertyInfo` entry to the `properties` list. @@ -1077,6 +1097,9 @@ class _ClassPageInfo(object): obj: The property object itself doc: The property's parsed docstring, a `_DocstringInfo`. """ + # Hide useless namedtuple docs-trings + if re.match('Alias for field number [0-9]+', doc.docstring): + doc = doc._replace(docstring='', brief='') property_info = _PropertyInfo(short_name, full_name, obj, doc) self._properties.append(property_info) @@ -1156,6 +1179,7 @@ class _ClassPageInfo(object): py_class: The class object being documented parser_config: An instance of ParserConfig. """ + self.set_namedtuplefields(py_class) doc_path = documentation_path(self.full_name) relative_path = os.path.relpath( path='.', start=os.path.dirname(doc_path) or '.') diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py index 9f6b185e81..71e96afa10 100644 --- a/tensorflow/tools/docs/parser_test.py +++ b/tensorflow/tools/docs/parser_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import functools import os import sys @@ -190,6 +191,50 @@ class ParserTest(googletest.TestCase): # Make sure this file is contained as the definition location. 
self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path) + def test_namedtuple_field_order(self): + namedtupleclass = collections.namedtuple('namedtupleclass', + {'z', 'y', 'x', 'w', 'v', 'u'}) + + index = { + 'namedtupleclass': namedtupleclass, + 'namedtupleclass.u': namedtupleclass.u, + 'namedtupleclass.v': namedtupleclass.v, + 'namedtupleclass.w': namedtupleclass.w, + 'namedtupleclass.x': namedtupleclass.x, + 'namedtupleclass.y': namedtupleclass.y, + 'namedtupleclass.z': namedtupleclass.z, + } + + visitor = DummyVisitor(index=index, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + + tree = {'namedtupleclass': {'u', 'v', 'w', 'x', 'y', 'z'}} + parser_config = parser.ParserConfig( + reference_resolver=reference_resolver, + duplicates={}, + duplicate_of={}, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir='/') + + page_info = parser.docs_for_object( + full_name='namedtupleclass', + py_object=namedtupleclass, + parser_config=parser_config) + + # Each namedtiple field has a docstring of the form: + # 'Alias for field number ##'. These props are returned sorted. 
+ + def sort_key(prop_info): + return int(prop_info.obj.__doc__.split(' ')[-1]) + + self.assertSequenceEqual(page_info.properties, + sorted(page_info.properties, key=sort_key)) + def test_docs_for_class_should_skip(self): class Parent(object): @@ -736,6 +781,5 @@ class TestGenerateSignature(googletest.TestCase): sig = parser._generate_signature(example_fun, reverse_index={}) self.assertEqual(sig, ['arg1=a.b.c.d', 'arg2=a.b.c.d(1, 2)', "arg3=e['f']"]) - if __name__ == '__main__': googletest.main() diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py index aecf753a58..448f246e0e 100644 --- a/tensorflow/tools/docs/pretty_docs.py +++ b/tensorflow/tools/docs/pretty_docs.py @@ -136,7 +136,7 @@ def _build_class_page(page_info): if page_info.properties: parts.append('## Properties\n\n') - for prop_info in sorted(page_info.properties): + for prop_info in page_info.properties: h3 = '<h3 id="{short_name}"><code>{short_name}</code></h3>\n\n' parts.append(h3.format(short_name=prop_info.short_name)) diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc index c8dc2a7c4d..d97496cbeb 100644 --- a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc +++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc @@ -92,7 +92,7 @@ Status ExtractMinMaxRecords(const string& log_file_name, if (!str_util::EndsWith(name_string, print_suffix)) { continue; } - string name = std::string( + string name( name_string.substr(0, name_string.size() - print_suffix.size())); records->push_back({name, min, max}); } diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc index dd95779a1f..b8d6ba00de 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc @@ -42,8 +42,8 @@ class SparsifyGatherTest : 
public ::testing::Test { const std::vector<NodeDef*>& inputs, GraphDef* graph_def, bool control_dep = false) { NodeDef* node_def = graph_def->add_node(); - node_def->set_name(std::string(name)); - node_def->set_op(std::string(op)); + node_def->set_name(string(name)); + node_def->set_op(string(op)); if (!control_dep) { std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) { node_def->add_input(input->name()); diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc index 5cae8f8d8f..7efe450710 100644 --- a/tensorflow/tools/graph_transforms/transform_graph.cc +++ b/tensorflow/tools/graph_transforms/transform_graph.cc @@ -65,19 +65,19 @@ Status ParseTransformParameters(const string& transforms_string, .GetResult(&remaining, &transform_name); if (!found_transform_name) { return errors::InvalidArgument("Looking for transform name, but found ", - std::string(remaining).c_str()); + string(remaining).c_str()); } if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) { state = TRANSFORM_PARAM_NAME; } else { // Add a transform with no parameters. 
- params_list->push_back({std::string(transform_name), func_parameters}); + params_list->push_back({string(transform_name), func_parameters}); transform_name = ""; state = TRANSFORM_NAME; } } else if (state == TRANSFORM_PARAM_NAME) { if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) { - params_list->push_back({std::string(transform_name), func_parameters}); + params_list->push_back({string(transform_name), func_parameters}); transform_name = ""; state = TRANSFORM_NAME; } else { @@ -92,13 +92,13 @@ Status ParseTransformParameters(const string& transforms_string, if (!found_parameter_name) { return errors::InvalidArgument( "Looking for parameter name, but found ", - std::string(remaining).c_str()); + string(remaining).c_str()); } if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) { state = TRANSFORM_PARAM_VALUE; } else { return errors::InvalidArgument("Looking for =, but found ", - std::string(remaining).c_str()); + string(remaining).c_str()); } } } else if (state == TRANSFORM_PARAM_VALUE) { @@ -120,10 +120,9 @@ Status ParseTransformParameters(const string& transforms_string, } if (!found_parameter_value) { return errors::InvalidArgument("Looking for parameter name, but found ", - std::string(remaining).c_str()); + string(remaining).c_str()); } - func_parameters[std::string(parameter_name)].push_back( - std::string(parameter_value)); + func_parameters[string(parameter_name)].emplace_back(parameter_value); // Eat up any trailing quotes. 
Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match); Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match); diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index cb084e49b7..c715380aae 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -93,7 +93,7 @@ void NodeNamePartsFromInput(const string& input_name, string* prefix, } else { *prefix = ""; } - *node_name = std::string(node_name_piece); + *node_name = string(node_name_piece); } string NodeNameFromInput(const string& input_name) { diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 997725d865..742f33f68e 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -491,11 +491,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/67bd0d9a0f5597f57f272061fd70f24dffb3d223.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/67bd0d9a0f5597f57f272061fd70f24dffb3d223.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/dc6d9ec3646865125d057b6f515b4543df79920a.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/dc6d9ec3646865125d057b6f515b4543df79920a.tar.gz", ], - sha256 = "b8f4ffbcaeea345e2245fd7028c7e960d71c2a2007c20bbfc5d79ecc86992a5e", - strip_prefix = "llvm-67bd0d9a0f5597f57f272061fd70f24dffb3d223", + sha256 = "c7252290a113f694cccbb4b325c67b56f3aa6f5b3044524302c0e79db2da7e2a", + strip_prefix = "llvm-dc6d9ec3646865125d057b6f515b4543df79920a", build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"), ) |