diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-31 15:17:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-31 15:20:38 -0700 |
commit | 4f6074494d4bf77daac5749224017615bfca239f (patch) | |
tree | 377d06acf45f3f15f00b4de24caa62872a255e0e /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | |
parent | 269a4ed1c27251b55cffe578b7bd969ec5975487 (diff) |
Move reorder-cast-and-transpose optimization to optimization stage.
PiperOrigin-RevId: 198788352
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 154 |
1 file changed, 96 insertions(+), 58 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 0edea16aac..ca3f84a81d 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -194,8 +194,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) { SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node); } -bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); } - NodeDef* GetTailOfValuePreservingChain( const NodeDef& node, const NodeMap& node_map, const std::unordered_set<string>& nodes_to_preserve) { @@ -1866,6 +1864,100 @@ class RemoveRedundantReshape : public ArithmeticOptimizerStage { } }; +// Reorder Cast and Transpose if beneficial. +// +// A common pattern after the layout optimizer is casting an uint8 NHWC +// image to float before transposing it to NCHW. It is beneficial to reorder +// the cast and the transpose to make the transpose process smaller amount +// of data. This optimization converts +// Transpose(Cast(image, dst_type), perm) +// to +// Cast(Transpose(image, perm), dst_type) +// when sizeof(image.type) < sizeof(dst_type). +// +// TODO(jingyue): This optimization can be generalized to a cast followed by +// a chain of ops that merely reorder elements (e.g. Reshape and +// DepthToSpace). +class ReorderCastAndTranspose : public ArithmeticOptimizerStage { + public: + explicit ReorderCastAndTranspose(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ReorderCastAndTranspose", ctx, ctx_ext) {} + ~ReorderCastAndTranspose() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsTranspose(*node) && NodeIsOnCpuOrGpu(node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + const NodeDef* transpose = node; + + // Verify that input to Transpose is the Cast op. 
+ NodeDef* cast; + TF_RETURN_IF_ERROR(GetInputNode(transpose->input(0), &cast)); + if (!IsCast(*cast)) return Status::OK(); + + // Input to the Cast-Transpose chain. + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(cast->input(0), &input)); + + const DataType src_type = GetSourceDataType(*cast); + const DataType dst_type = GetDestinationDataType(*cast); + + const string src_type_name = DataTypeString(src_type); + const string dst_type_name = DataTypeString(dst_type); + + // Check if nodes were not already optimized. + const string optimized_cast_name = + OptimizedNodeName(ParseNodeScopeAndName(cast->name()), dst_type_name); + const string optimized_transpose_name = OptimizedNodeName( + ParseNodeScopeAndName(transpose->name()), src_type_name); + + bool is_already_optimized = + ctx().node_map->NodeExists(optimized_transpose_name) || + ctx().node_map->NodeExists(optimized_cast_name); + + if (IsNumberType(src_type) && IsNumberType(dst_type) && + DataTypeSize(src_type) < DataTypeSize(dst_type) && + !is_already_optimized) { + NodeDef* new_transpose = AddCopyNode(optimized_transpose_name, transpose); + (*new_transpose->mutable_attr())["T"].set_type(src_type); + new_transpose->set_input(0, cast->input(0)); + + ctx().node_map->AddOutput(input->name(), new_transpose->name()); + ctx().node_map->AddOutput(NodeName(new_transpose->input(1)), + new_transpose->name()); + + NodeDef* new_cast = AddCopyNode(optimized_cast_name, cast); + new_cast->set_input(0, new_transpose->name()); + ctx().node_map->AddOutput(new_transpose->name(), new_cast->name()); + + AddToOptimizationQueue(new_transpose); + ForwardControlDependencies(new_transpose, {cast, node}); + + *simplified_node_name = new_cast->name(); + } + + return Status::OK(); + } + + private: + // This optimization can be dangerous on devices other than CPU and + // GPU. The transpose might not be implemented for image.type, or + // might be slower with image.type than with dst_type. 
+ bool NodeIsOnCpuOrGpu(const NodeDef* node) const { + using str_util::StrContains; + + string task; + string device; + + return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) && + (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU)); + } + + bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); } +}; + } // namespace class UniqueNodes { @@ -2118,62 +2210,6 @@ void ArithmeticOptimizer::ForwardControlDependencies( // ArithmeticOptimizerStage string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) { - - if (node->op() == "Transpose") { - // Reorder Cast and Transpose if beneficial. - // - // A common pattern after the layout optimizer is casting an uint8 NHWC - // image to float before transposing it to NCHW. It is beneficial to reorder - // the cast and the transpose to make the transpose process smaller amount - // of data. This optimization converts - // Transpose(Cast(image, dst_type), perm) - // to - // Cast(Transpose(image, perm), dst_type) - // when sizeof(image.type) < sizeof(dst_type). - // - // TODO(jingyue): This optimization can be generalized to a cast followed by - // a chain of ops that merely reorder elements (e.g. Reshape and - // DepthToSpace). - const NodeDef* transpose = node; - string dontcare; - string device; - // This optimization can be dangerous on devices other than CPU and GPU. The - // transpose might not be implemented for image.type, or might be slower - // with image.type than with dst_type. 
- if (DeviceNameUtils::SplitDeviceName(transpose->device(), &dontcare, - &device) && - (str_util::StrContains(device, DEVICE_CPU) || - str_util::StrContains(device, DEVICE_GPU))) { - const NodeDef* cast = node_map_->GetNode(transpose->input(0)); - if (cast->op() == "Cast") { - const NodeDef* input = node_map_->GetNode(cast->input(0)); - const DataType src_type = GetSourceDataType(*cast); - const DataType dst_type = GetDestinationDataType(*cast); - if (IsNumberType(src_type) && IsNumberType(dst_type) && - DataTypeSize(src_type) < DataTypeSize(dst_type) && - !OptimizedNodeExists(*cast, DataTypeString(dst_type)) && - !OptimizedNodeExists(*transpose, DataTypeString(src_type))) { - NodeDef* new_transpose = AddNode(*transpose, DataTypeString(src_type), - /*copy_node=*/true); - (*new_transpose->mutable_attr())["T"].set_type(src_type); - new_transpose->set_input(0, cast->input(0)); - node_map_->AddOutput(input->name(), new_transpose->name()); - node_map_->AddOutput(NodeName(new_transpose->input(1)), - new_transpose->name()); - - NodeDef* new_cast = - AddNode(*cast, DataTypeString(dst_type), /*copy_node=*/true); - new_cast->set_input(0, new_transpose->name()); - node_map_->AddOutput(new_transpose->name(), new_cast->name()); - - nodes_to_simplify->PushBack(new_transpose); - ForwardControlDependencies(new_transpose, {cast, node}); - return new_cast->name(); - } - } - } - } - // Fold a multiply of a scalar into the following convolution. This folding // can jump across nodes that merely reorders data (such as reshape and // transpose). 
For example, we can optimize @@ -2462,6 +2498,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext); if (options_.remove_logical_not) pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext); + if (options_.reorder_cast_and_transpose) + pipeline.AddStage<ReorderCastAndTranspose>(ctx, ctx_ext); if (options_.hoist_cwise_unary_chains) pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext); if (options_.convert_sqrt_div_to_rsqrt_mul) |