aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/js
diff options
context:
space:
mode:
authorGravatar Nick Kreeger <kreeger@google.com>2018-08-24 10:45:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-24 10:49:14 -0700
commit90030cc1ef1ce88cc9dc017ac99e495fab65077d (patch)
treeb0aae86927d64b0b627378d3463f737bf4593365 /tensorflow/js
parent37b2b0eb613b6c3c66b96374851cfd95050346a0 (diff)
Generate TypeScript Op attribute values for "type" and "int" OpDef attribute types.
This is an incremental change to first introduce updates to the TypeScript internal library and references to building OpDef attribute structs that the TensorFlow.js Node runtime uses. For now, this change introduces basic "type" and "int" attr types. I'll continue to roll more types and complicated examples in upcoming changes. PiperOrigin-RevId: 210121141
Diffstat (limited to 'tensorflow/js')
-rw-r--r--tensorflow/js/ops/ts_op_gen.cc93
-rw-r--r--tensorflow/js/ops/ts_op_gen_test.cc138
2 files changed, 178 insertions, 53 deletions
diff --git a/tensorflow/js/ops/ts_op_gen.cc b/tensorflow/js/ops/ts_op_gen.cc
index babf55cd5f..fb93bb6d8e 100644
--- a/tensorflow/js/ops/ts_op_gen.cc
+++ b/tensorflow/js/ops/ts_op_gen.cc
@@ -38,6 +38,15 @@ struct ArgDefs {
const ApiDef::Arg& api_def_arg;
};
+// Struct to hold a combo OpDef::AttrDef and ApiDef::Attr for an Op.
+struct OpAttrs {
+ OpAttrs(const OpDef::AttrDef& op_def_attr, const ApiDef::Attr& api_def_attr)
+ : op_def_attr(op_def_attr), api_def_attr(api_def_attr) {}
+
+ const OpDef::AttrDef& op_def_attr;
+ const ApiDef::Attr& api_def_attr;
+};
+
// Helper class to generate TypeScript code for a given OpDef:
class GenTypeScriptOp {
public:
@@ -49,8 +58,12 @@ class GenTypeScriptOp {
private:
void ProcessArgs();
+ void ProcessAttrs();
+ void AddAttrForArg(const string& attr, int arg_index);
+ string InputForAttr(const OpDef::AttrDef& op_def_attr);
void AddMethodSignature();
+ void AddOpAttrs();
void AddMethodReturnAndClose();
const OpDef& op_def_;
@@ -62,6 +75,13 @@ class GenTypeScriptOp {
// Holds in-order vector of Op inputs:
std::vector<ArgDefs> input_op_args_;
+ // Holds in-order vector of Op attributes:
+ std::vector<OpAttrs> op_attrs_;
+
+ // Stores attributes-to-arguments by name:
+ typedef std::unordered_map<string, std::vector<int>> AttrArgIdxMap;
+ AttrArgIdxMap attr_arg_idx_map_;
+
// Holds number of outputs:
int num_outputs_;
};
@@ -73,9 +93,11 @@ GenTypeScriptOp::~GenTypeScriptOp() {}
string GenTypeScriptOp::Code() {
ProcessArgs();
+ ProcessAttrs();
// Generate exported function for Op:
AddMethodSignature();
+ AddOpAttrs();
AddMethodReturnAndClose();
strings::StrAppend(&result_, "\n");
@@ -96,12 +118,52 @@ void GenTypeScriptOp::ProcessArgs() {
<< api_def_.arg_order(i);
continue;
}
+
+ // Map attr names to arg indexes:
+ if (!op_def_arg->type_attr().empty()) {
+ AddAttrForArg(op_def_arg->type_attr(), i);
+ } else if (!op_def_arg->type_list_attr().empty()) {
+ AddAttrForArg(op_def_arg->type_list_attr(), i);
+ }
+ if (!op_def_arg->number_attr().empty()) {
+ AddAttrForArg(op_def_arg->number_attr(), i);
+ }
+
input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg));
}
num_outputs_ = api_def_.out_arg_size();
}
+void GenTypeScriptOp::ProcessAttrs() {
+ for (int i = 0; i < op_def_.attr_size(); i++) {
+ op_attrs_.push_back(OpAttrs(op_def_.attr(i), api_def_.attr(i)));
+ }
+}
+
+void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) {
+ // Keep track of attributes-to-arguments by name. These will be used for
+ // construction Op attributes that require information about the inputs.
+ auto iter = attr_arg_idx_map_.find(attr);
+ if (iter == attr_arg_idx_map_.end()) {
+ attr_arg_idx_map_.insert(AttrArgIdxMap::value_type(attr, {arg_index}));
+ } else {
+ iter->second.push_back(arg_index);
+ }
+}
+
+string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) {
+ string inputs;
+ auto arg_list = attr_arg_idx_map_.find(op_def_attr.name());
+ if (arg_list != attr_arg_idx_map_.end()) {
+ for (auto iter = arg_list->second.begin(); iter != arg_list->second.end();
+ ++iter) {
+ strings::StrAppend(&inputs, input_op_args_[*iter].op_def_arg.name());
+ }
+ }
+ return inputs;
+}
+
void GenTypeScriptOp::AddMethodSignature() {
strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(),
"(");
@@ -131,6 +193,35 @@ void GenTypeScriptOp::AddMethodSignature() {
}
}
+void GenTypeScriptOp::AddOpAttrs() {
+ strings::StrAppend(&result_, " const opAttrs = [\n");
+
+ bool is_first = true;
+ for (auto& attr : op_attrs_) {
+ if (is_first) {
+ is_first = false;
+ } else {
+ strings::StrAppend(&result_, ",\n");
+ }
+
+ // Append 4 spaces to start:
+ strings::StrAppend(&result_, " ");
+
+ if (attr.op_def_attr.type() == "type") {
+ // Type OpAttributes can be generated from a helper function:
+ strings::StrAppend(&result_, "createTensorsTypeOpAttr('",
+ attr.op_def_attr.name(), "', ",
+ InputForAttr(attr.op_def_attr), ")");
+ } else if (attr.op_def_attr.type() == "int") {
+ strings::StrAppend(&result_, "{name: '", attr.op_def_attr.name(), "', ");
+ strings::StrAppend(&result_, "type: nodeBackend().binding.TF_ATTR_INT, ");
+ strings::StrAppend(&result_, "value: ", InputForAttr(attr.op_def_attr),
+ ".length}");
+ }
+ }
+ strings::StrAppend(&result_, "\n ];\n");
+}
+
void GenTypeScriptOp::AddMethodReturnAndClose() {
strings::StrAppend(&result_, " return null;\n}\n");
}
@@ -162,7 +253,7 @@ void StartFile(WritableFile* ts_file) {
// This file is MACHINE GENERATED! Do not edit
import * as tfc from '@tensorflow/tfjs-core';
-import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
+import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)header";
diff --git a/tensorflow/js/ops/ts_op_gen_test.cc b/tensorflow/js/ops/ts_op_gen_test.cc
index 9a85c021b0..03241689b5 100644
--- a/tensorflow/js/ops/ts_op_gen_test.cc
+++ b/tensorflow/js/ops/ts_op_gen_test.cc
@@ -36,7 +36,6 @@ void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
<< "'" << s << "' does not contain '" << expected << "'";
}
-// TODO(kreeger): Add multiple outputs here?
constexpr char kBaseOpDef[] = R"(
op {
name: "Foo"
@@ -79,50 +78,15 @@ op {
summary: "Summary for op Foo."
description: "Description for op Foo."
}
-op {
- name: "DeprecatedFoo"
- input_arg {
- name: "input"
- description: "Description for input."
- type: DT_FLOAT
- }
- output_arg {
- name: "output"
- description: "Description for output."
- type: DT_FLOAT
- }
- deprecation {
- explanation: "Deprecated."
- }
-}
-op {
- name: "MultiOutputFoo"
- input_arg {
- name: "input"
- description: "Description for input."
- type: DT_FLOAT
- }
- output_arg {
- name: "output1"
- description: "Description for output 1."
- type: DT_FLOAT
- }
- output_arg {
- name: "output2"
- description: "Description for output 2."
- type: DT_FLOAT
- }
- summary: "Summary for op MultiOutputFoo."
- description: "Description for op MultiOutputFoo."
-}
)";
// Generate TypeScript code
-// @param api_def_str TODO doc me.
-void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
+void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
+ string* ts_file_text) {
Env* env = Env::Default();
OpList op_defs;
- protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
+ protobuf::TextFormat::ParseFromString(
+ op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs);
ApiDefMap api_def_map(op_defs);
if (!api_def_str.empty()) {
@@ -138,11 +102,11 @@ void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
TEST(TsOpGenTest, TestImports) {
string ts_file_text;
- GenerateTsOpFileText("", &ts_file_text);
+ GenerateTsOpFileText("", "", &ts_file_text);
const string expected = R"(
import * as tfc from '@tensorflow/tfjs-core';
-import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
+import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)";
ExpectContainsStr(ts_file_text, expected);
}
@@ -160,12 +124,10 @@ op {
)";
string ts_file_text;
- GenerateTsOpFileText(api_def, &ts_file_text);
+ GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
- return null;
-}
)";
ExpectContainsStr(ts_file_text, expected);
}
@@ -179,34 +141,106 @@ op {
)";
string ts_file_text;
- GenerateTsOpFileText(api_def, &ts_file_text);
+ GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
- return null;
-}
)";
ExpectDoesNotContainStr(ts_file_text, expected);
}
TEST(TsOpGenTest, SkipDeprecated) {
+ const string op_def = R"(
+op {
+ name: "DeprecatedFoo"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ description: "Description for input."
+ }
+ output_arg {
+ name: "output"
+ description: "Description for output."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for input"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ deprecation {
+ explanation: "Deprecated."
+ }
+}
+)";
+
string ts_file_text;
- GenerateTsOpFileText("", &ts_file_text);
+ GenerateTsOpFileText(op_def, "", &ts_file_text);
ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
}
TEST(TsOpGenTest, MultiOutput) {
+ const string op_def = R"(
+op {
+ name: "MultiOutputFoo"
+ input_arg {
+ name: "input"
+ description: "Description for input."
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output1"
+ description: "Description for output 1."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output2"
+ description: "Description for output 2."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for input"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ summary: "Summary for op MultiOutputFoo."
+ description: "Description for op MultiOutputFoo."
+}
+)";
+
string ts_file_text;
- GenerateTsOpFileText("", &ts_file_text);
+ GenerateTsOpFileText(op_def, "", &ts_file_text);
const string expected = R"(
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
- return null;
-}
)";
ExpectContainsStr(ts_file_text, expected);
}
+TEST(TsOpGenTest, OpAttrs) {
+ string ts_file_text;
+ GenerateTsOpFileText("", "", &ts_file_text);
+
+ const string expectedFooAttrs = R"(
+ const opAttrs = [
+ createTensorsTypeOpAttr('T', images),
+ {name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length}
+ ];
+)";
+
+ ExpectContainsStr(ts_file_text, expectedFooAttrs);
+}
+
} // namespace
} // namespace tensorflow