author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-31 15:17:52 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-05-31 15:20:38 -0700
commit    4f6074494d4bf77daac5749224017615bfca239f (patch)
tree      377d06acf45f3f15f00b4de24caa62872a255e0e /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parent    269a4ed1c27251b55cffe578b7bd969ec5975487 (diff)
Move reorder-cast-and-transpose optimization to optimization stage.
PiperOrigin-RevId: 198788352
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc  154
1 file changed, 96 insertions(+), 58 deletions(-)
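Why the reorder helps: a Transpose touches every byte of its input, so moving it ahead of the widening Cast lets it shuffle the narrow source type instead of the wide destination type. A back-of-the-envelope sketch in plain C++, assuming an illustrative 32x224x224x3 uint8 image batch cast to float (the shape and dtypes are assumptions, not values from the commit):

#include <cstdio>

// Bytes moved by the Transpose in each graph for a uint8 -> float cast of an
// NHWC image batch. Shape and dtypes are illustrative assumptions only.
int main() {
  const long long elems = 32LL * 224 * 224 * 3;  // N * H * W * C
  const long long src_bytes = 1;                 // sizeof(uint8): image.type
  const long long dst_bytes = 4;                 // sizeof(float): dst_type
  // Before: Transpose(Cast(image, float), perm) shuffles 4-byte floats.
  std::printf("Transpose after Cast:  %lld bytes\n", elems * dst_bytes);
  // After:  Cast(Transpose(image, perm), float) shuffles 1-byte uint8s.
  std::printf("Transpose before Cast: %lld bytes\n", elems * src_bytes);
  return 0;
}

For this shape the rewritten graph's Transpose moves a quarter of the bytes, which is precisely the sizeof(image.type) < sizeof(dst_type) condition the new stage checks via DataTypeSize.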
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 0edea16aac..ca3f84a81d 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -194,8 +194,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node);
}
-bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
-
NodeDef* GetTailOfValuePreservingChain(
const NodeDef& node, const NodeMap& node_map,
const std::unordered_set<string>& nodes_to_preserve) {
@@ -1866,6 +1864,100 @@ class RemoveRedundantReshape : public ArithmeticOptimizerStage {
}
};
+// Reorder Cast and Transpose if beneficial.
+//
+// A common pattern after the layout optimizer is casting a uint8 NHWC
+// image to float before transposing it to NCHW. It is beneficial to reorder
+// the cast and the transpose so that the transpose processes a smaller
+// amount of data. This optimization converts
+// Transpose(Cast(image, dst_type), perm)
+// to
+// Cast(Transpose(image, perm), dst_type)
+// when sizeof(image.type) < sizeof(dst_type).
+//
+// TODO(jingyue): This optimization can be generalized to a cast followed by
+// a chain of ops that merely reorder elements (e.g. Reshape and
+// DepthToSpace).
+class ReorderCastAndTranspose : public ArithmeticOptimizerStage {
+ public:
+ explicit ReorderCastAndTranspose(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ReorderCastAndTranspose", ctx, ctx_ext) {}
+ ~ReorderCastAndTranspose() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsTranspose(*node) && NodeIsOnCpuOrGpu(node);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const NodeDef* transpose = node;
+
+ // Verify that input to Transpose is the Cast op.
+ NodeDef* cast;
+ TF_RETURN_IF_ERROR(GetInputNode(transpose->input(0), &cast));
+ if (!IsCast(*cast)) return Status::OK();
+
+ // Input to the Cast-Transpose chain.
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(cast->input(0), &input));
+
+ const DataType src_type = GetSourceDataType(*cast);
+ const DataType dst_type = GetDestinationDataType(*cast);
+
+ const string src_type_name = DataTypeString(src_type);
+ const string dst_type_name = DataTypeString(dst_type);
+
+  // Check whether the nodes have already been optimized.
+ const string optimized_cast_name =
+ OptimizedNodeName(ParseNodeScopeAndName(cast->name()), dst_type_name);
+ const string optimized_transpose_name = OptimizedNodeName(
+ ParseNodeScopeAndName(transpose->name()), src_type_name);
+
+ bool is_already_optimized =
+ ctx().node_map->NodeExists(optimized_transpose_name) ||
+ ctx().node_map->NodeExists(optimized_cast_name);
+
+ if (IsNumberType(src_type) && IsNumberType(dst_type) &&
+ DataTypeSize(src_type) < DataTypeSize(dst_type) &&
+ !is_already_optimized) {
+ NodeDef* new_transpose = AddCopyNode(optimized_transpose_name, transpose);
+ (*new_transpose->mutable_attr())["T"].set_type(src_type);
+ new_transpose->set_input(0, cast->input(0));
+
+ ctx().node_map->AddOutput(input->name(), new_transpose->name());
+ ctx().node_map->AddOutput(NodeName(new_transpose->input(1)),
+ new_transpose->name());
+
+ NodeDef* new_cast = AddCopyNode(optimized_cast_name, cast);
+ new_cast->set_input(0, new_transpose->name());
+ ctx().node_map->AddOutput(new_transpose->name(), new_cast->name());
+
+ AddToOptimizationQueue(new_transpose);
+ ForwardControlDependencies(new_transpose, {cast, node});
+
+ *simplified_node_name = new_cast->name();
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ // This optimization can be dangerous on devices other than CPU and
+ // GPU. The transpose might not be implemented for image.type, or
+ // might be slower with image.type than with dst_type.
+ bool NodeIsOnCpuOrGpu(const NodeDef* node) const {
+ using str_util::StrContains;
+
+ string task;
+ string device;
+
+ return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
+ (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU));
+ }
+
+ bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
+};
+
} // namespace
class UniqueNodes {
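The is_already_optimized guard above is what keeps the stage idempotent: the replacement Cast and Transpose receive deterministic names derived from the original node name plus a data-type suffix, so a second pass over the graph finds them by name and skips the rewrite. A simplified sketch of that scheme, where OptimizedName is a hypothetical stand-in for grappler's OptimizedNodeName (the real helper also handles node scopes and stage prefixes):

#include <functional>
#include <set>
#include <string>

// Hypothetical stand-in for grappler's OptimizedNodeName: derive a stable,
// deterministic name from the original node plus a data-type suffix.
std::string OptimizedName(const std::string& node, const std::string& suffix) {
  return "ArithmeticOptimizer/" + node + "_" + suffix;
}

// Mirrors the is_already_optimized check in the stage above: the rewrite is
// skipped when either derived node already exists in the node map.
bool AlreadyOptimized(
    const std::string& cast_name, const std::string& transpose_name,
    const std::string& src_type_name, const std::string& dst_type_name,
    const std::function<bool(const std::string&)>& node_exists) {
  return node_exists(OptimizedName(cast_name, dst_type_name)) ||
         node_exists(OptimizedName(transpose_name, src_type_name));
}

int main() {
  std::set<std::string> nodes = {"ArithmeticOptimizer/transpose_uint8"};
  auto exists = [&](const std::string& n) { return nodes.count(n) > 0; };
  // The transpose was already rewritten, so the stage would skip this pair.
  return AlreadyOptimized("cast", "transpose", "uint8", "float", exists) ? 0 : 1;
}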
@@ -2118,62 +2210,6 @@ void ArithmeticOptimizer::ForwardControlDependencies(
// ArithmeticOptimizerStage
string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
-
- if (node->op() == "Transpose") {
- // Reorder Cast and Transpose if beneficial.
- //
- // A common pattern after the layout optimizer is casting an uint8 NHWC
- // image to float before transposing it to NCHW. It is beneficial to reorder
- // the cast and the transpose to make the transpose process smaller amount
- // of data. This optimization converts
- // Transpose(Cast(image, dst_type), perm)
- // to
- // Cast(Transpose(image, perm), dst_type)
- // when sizeof(image.type) < sizeof(dst_type).
- //
- // TODO(jingyue): This optimization can be generalized to a cast followed by
- // a chain of ops that merely reorder elements (e.g. Reshape and
- // DepthToSpace).
- const NodeDef* transpose = node;
- string dontcare;
- string device;
- // This optimization can be dangerous on devices other than CPU and GPU. The
- // transpose might not be implemented for image.type, or might be slower
- // with image.type than with dst_type.
- if (DeviceNameUtils::SplitDeviceName(transpose->device(), &dontcare,
- &device) &&
- (str_util::StrContains(device, DEVICE_CPU) ||
- str_util::StrContains(device, DEVICE_GPU))) {
- const NodeDef* cast = node_map_->GetNode(transpose->input(0));
- if (cast->op() == "Cast") {
- const NodeDef* input = node_map_->GetNode(cast->input(0));
- const DataType src_type = GetSourceDataType(*cast);
- const DataType dst_type = GetDestinationDataType(*cast);
- if (IsNumberType(src_type) && IsNumberType(dst_type) &&
- DataTypeSize(src_type) < DataTypeSize(dst_type) &&
- !OptimizedNodeExists(*cast, DataTypeString(dst_type)) &&
- !OptimizedNodeExists(*transpose, DataTypeString(src_type))) {
- NodeDef* new_transpose = AddNode(*transpose, DataTypeString(src_type),
- /*copy_node=*/true);
- (*new_transpose->mutable_attr())["T"].set_type(src_type);
- new_transpose->set_input(0, cast->input(0));
- node_map_->AddOutput(input->name(), new_transpose->name());
- node_map_->AddOutput(NodeName(new_transpose->input(1)),
- new_transpose->name());
-
- NodeDef* new_cast =
- AddNode(*cast, DataTypeString(dst_type), /*copy_node=*/true);
- new_cast->set_input(0, new_transpose->name());
- node_map_->AddOutput(new_transpose->name(), new_cast->name());
-
- nodes_to_simplify->PushBack(new_transpose);
- ForwardControlDependencies(new_transpose, {cast, node});
- return new_cast->name();
- }
- }
- }
- }
-
// Fold a multiply of a scalar into the following convolution. This folding
// can jump across nodes that merely reorder data (such as reshape and
// transpose). For example, we can optimize
@@ -2462,6 +2498,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
if (options_.remove_logical_not)
pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
+ if (options_.reorder_cast_and_transpose)
+ pipeline.AddStage<ReorderCastAndTranspose>(ctx, ctx_ext);
if (options_.hoist_cwise_unary_chains)
pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
if (options_.convert_sqrt_div_to_rsqrt_mul)
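The two added lines follow the same registration pattern as every other stage in SimplifyArithmeticOps: one option flag, one AddStage call, and the pipeline owns the stage objects. A minimal sketch of that pattern, assuming a simplified interface rather than the real GraphOptimizerStagePipeline API:

#include <cstdio>
#include <memory>
#include <utility>
#include <vector>

// Minimal sketch of the stage-pipeline pattern: each rewrite is a
// self-contained stage, registered only when its option flag is enabled.
struct Stage {
  virtual ~Stage() = default;
  virtual void Run() = 0;
};

struct ReorderStage : Stage {
  void Run() override { std::puts("reorder cast and transpose"); }
};

class Pipeline {
 public:
  template <typename T, typename... Args>
  void AddStage(Args&&... args) {
    stages_.push_back(std::make_unique<T>(std::forward<Args>(args)...));
  }
  void RunAll() {
    for (auto& stage : stages_) stage->Run();
  }

 private:
  std::vector<std::unique_ptr<Stage>> stages_;
};

int main() {
  bool reorder_cast_and_transpose = true;  // stands in for options_.*
  Pipeline pipeline;
  if (reorder_cast_and_transpose) pipeline.AddStage<ReorderStage>();
  pipeline.RunAll();
  return 0;
}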