aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/js
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/js')
-rw-r--r--tensorflow/js/ops/ts_op_gen.cc93
-rw-r--r--tensorflow/js/ops/ts_op_gen_test.cc138
2 files changed, 178 insertions, 53 deletions
diff --git a/tensorflow/js/ops/ts_op_gen.cc b/tensorflow/js/ops/ts_op_gen.cc
index babf55cd5f..fb93bb6d8e 100644
--- a/tensorflow/js/ops/ts_op_gen.cc
+++ b/tensorflow/js/ops/ts_op_gen.cc
@@ -38,6 +38,15 @@ struct ArgDefs {
const ApiDef::Arg& api_def_arg;
};
+// Struct to hold a combo OpDef::AttrDef and ApiDef::Attr for an Op.
+struct OpAttrs {
+ OpAttrs(const OpDef::AttrDef& op_def_attr, const ApiDef::Attr& api_def_attr)
+ : op_def_attr(op_def_attr), api_def_attr(api_def_attr) {}
+
+ const OpDef::AttrDef& op_def_attr;
+ const ApiDef::Attr& api_def_attr;
+};
+
// Helper class to generate TypeScript code for a given OpDef:
class GenTypeScriptOp {
public:
@@ -49,8 +58,12 @@ class GenTypeScriptOp {
private:
void ProcessArgs();
+ void ProcessAttrs();
+ void AddAttrForArg(const string& attr, int arg_index);
+ string InputForAttr(const OpDef::AttrDef& op_def_attr);
void AddMethodSignature();
+ void AddOpAttrs();
void AddMethodReturnAndClose();
const OpDef& op_def_;
@@ -62,6 +75,13 @@ class GenTypeScriptOp {
// Holds in-order vector of Op inputs:
std::vector<ArgDefs> input_op_args_;
+ // Holds in-order vector of Op attributes:
+ std::vector<OpAttrs> op_attrs_;
+
+ // Stores attributes-to-arguments by name:
+ typedef std::unordered_map<string, std::vector<int>> AttrArgIdxMap;
+ AttrArgIdxMap attr_arg_idx_map_;
+
// Holds number of outputs:
int num_outputs_;
};
@@ -73,9 +93,11 @@ GenTypeScriptOp::~GenTypeScriptOp() {}
string GenTypeScriptOp::Code() {
ProcessArgs();
+ ProcessAttrs();
// Generate exported function for Op:
AddMethodSignature();
+ AddOpAttrs();
AddMethodReturnAndClose();
strings::StrAppend(&result_, "\n");
@@ -96,12 +118,52 @@ void GenTypeScriptOp::ProcessArgs() {
<< api_def_.arg_order(i);
continue;
}
+
+ // Map attr names to arg indexes:
+ if (!op_def_arg->type_attr().empty()) {
+ AddAttrForArg(op_def_arg->type_attr(), i);
+ } else if (!op_def_arg->type_list_attr().empty()) {
+ AddAttrForArg(op_def_arg->type_list_attr(), i);
+ }
+ if (!op_def_arg->number_attr().empty()) {
+ AddAttrForArg(op_def_arg->number_attr(), i);
+ }
+
input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg));
}
num_outputs_ = api_def_.out_arg_size();
}
+void GenTypeScriptOp::ProcessAttrs() {
+ for (int i = 0; i < op_def_.attr_size(); i++) {
+ op_attrs_.push_back(OpAttrs(op_def_.attr(i), api_def_.attr(i)));
+ }
+}
+
+void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) {
+ // Keep track of attributes-to-arguments by name. These will be used for
+ // construction Op attributes that require information about the inputs.
+ auto iter = attr_arg_idx_map_.find(attr);
+ if (iter == attr_arg_idx_map_.end()) {
+ attr_arg_idx_map_.insert(AttrArgIdxMap::value_type(attr, {arg_index}));
+ } else {
+ iter->second.push_back(arg_index);
+ }
+}
+
+string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) {
+ string inputs;
+ auto arg_list = attr_arg_idx_map_.find(op_def_attr.name());
+ if (arg_list != attr_arg_idx_map_.end()) {
+ for (auto iter = arg_list->second.begin(); iter != arg_list->second.end();
+ ++iter) {
+ strings::StrAppend(&inputs, input_op_args_[*iter].op_def_arg.name());
+ }
+ }
+ return inputs;
+}
+
void GenTypeScriptOp::AddMethodSignature() {
strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(),
"(");
@@ -131,6 +193,35 @@ void GenTypeScriptOp::AddMethodSignature() {
}
}
+void GenTypeScriptOp::AddOpAttrs() {
+ strings::StrAppend(&result_, " const opAttrs = [\n");
+
+ bool is_first = true;
+ for (auto& attr : op_attrs_) {
+ if (is_first) {
+ is_first = false;
+ } else {
+ strings::StrAppend(&result_, ",\n");
+ }
+
+ // Append 4 spaces to start:
+ strings::StrAppend(&result_, " ");
+
+ if (attr.op_def_attr.type() == "type") {
+ // Type OpAttributes can be generated from a helper function:
+ strings::StrAppend(&result_, "createTensorsTypeOpAttr('",
+ attr.op_def_attr.name(), "', ",
+ InputForAttr(attr.op_def_attr), ")");
+ } else if (attr.op_def_attr.type() == "int") {
+ strings::StrAppend(&result_, "{name: '", attr.op_def_attr.name(), "', ");
+ strings::StrAppend(&result_, "type: nodeBackend().binding.TF_ATTR_INT, ");
+ strings::StrAppend(&result_, "value: ", InputForAttr(attr.op_def_attr),
+ ".length}");
+ }
+ }
+ strings::StrAppend(&result_, "\n ];\n");
+}
+
void GenTypeScriptOp::AddMethodReturnAndClose() {
strings::StrAppend(&result_, " return null;\n}\n");
}
@@ -162,7 +253,7 @@ void StartFile(WritableFile* ts_file) {
// This file is MACHINE GENERATED! Do not edit
import * as tfc from '@tensorflow/tfjs-core';
-import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
+import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)header";
diff --git a/tensorflow/js/ops/ts_op_gen_test.cc b/tensorflow/js/ops/ts_op_gen_test.cc
index 9a85c021b0..03241689b5 100644
--- a/tensorflow/js/ops/ts_op_gen_test.cc
+++ b/tensorflow/js/ops/ts_op_gen_test.cc
@@ -36,7 +36,6 @@ void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
<< "'" << s << "' does not contain '" << expected << "'";
}
-// TODO(kreeger): Add multiple outputs here?
constexpr char kBaseOpDef[] = R"(
op {
name: "Foo"
@@ -79,50 +78,15 @@ op {
summary: "Summary for op Foo."
description: "Description for op Foo."
}
-op {
- name: "DeprecatedFoo"
- input_arg {
- name: "input"
- description: "Description for input."
- type: DT_FLOAT
- }
- output_arg {
- name: "output"
- description: "Description for output."
- type: DT_FLOAT
- }
- deprecation {
- explanation: "Deprecated."
- }
-}
-op {
- name: "MultiOutputFoo"
- input_arg {
- name: "input"
- description: "Description for input."
- type: DT_FLOAT
- }
- output_arg {
- name: "output1"
- description: "Description for output 1."
- type: DT_FLOAT
- }
- output_arg {
- name: "output2"
- description: "Description for output 2."
- type: DT_FLOAT
- }
- summary: "Summary for op MultiOutputFoo."
- description: "Description for op MultiOutputFoo."
-}
)";
// Generate TypeScript code
-// @param api_def_str TODO doc me.
-void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
+void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
+ string* ts_file_text) {
Env* env = Env::Default();
OpList op_defs;
- protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
+ protobuf::TextFormat::ParseFromString(
+ op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs);
ApiDefMap api_def_map(op_defs);
if (!api_def_str.empty()) {
@@ -138,11 +102,11 @@ void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
TEST(TsOpGenTest, TestImports) {
string ts_file_text;
- GenerateTsOpFileText("", &ts_file_text);
+ GenerateTsOpFileText("", "", &ts_file_text);
const string expected = R"(
import * as tfc from '@tensorflow/tfjs-core';
-import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
+import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)";
ExpectContainsStr(ts_file_text, expected);
}
@@ -160,12 +124,10 @@ op {
)";
string ts_file_text;
- GenerateTsOpFileText(api_def, &ts_file_text);
+ GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
- return null;
-}
)";
ExpectContainsStr(ts_file_text, expected);
}
@@ -179,34 +141,106 @@ op {
)";
string ts_file_text;
- GenerateTsOpFileText(api_def, &ts_file_text);
+ GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
- return null;
-}
)";
ExpectDoesNotContainStr(ts_file_text, expected);
}
TEST(TsOpGenTest, SkipDeprecated) {
+ const string op_def = R"(
+op {
+ name: "DeprecatedFoo"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ description: "Description for input."
+ }
+ output_arg {
+ name: "output"
+ description: "Description for output."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for input"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ deprecation {
+ explanation: "Deprecated."
+ }
+}
+)";
+
string ts_file_text;
- GenerateTsOpFileText("", &ts_file_text);
+ GenerateTsOpFileText(op_def, "", &ts_file_text);
ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
}
TEST(TsOpGenTest, MultiOutput) {
+ const string op_def = R"(
+op {
+ name: "MultiOutputFoo"
+ input_arg {
+ name: "input"
+ description: "Description for input."
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output1"
+ description: "Description for output 1."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output2"
+ description: "Description for output 2."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for input"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ summary: "Summary for op MultiOutputFoo."
+ description: "Description for op MultiOutputFoo."
+}
+)";
+
string ts_file_text;
- GenerateTsOpFileText("", &ts_file_text);
+ GenerateTsOpFileText(op_def, "", &ts_file_text);
const string expected = R"(
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
- return null;
-}
)";
ExpectContainsStr(ts_file_text, expected);
}
+TEST(TsOpGenTest, OpAttrs) {
+ string ts_file_text;
+ GenerateTsOpFileText("", "", &ts_file_text);
+
+ const string expectedFooAttrs = R"(
+ const opAttrs = [
+ createTensorsTypeOpAttr('T', images),
+ {name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length}
+ ];
+)";
+
+ ExpectContainsStr(ts_file_text, expectedFooAttrs);
+}
+
} // namespace
} // namespace tensorflow