diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java
index d9da12b62..6831af10b 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.nn.dtypes;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
+import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j;
@@ -128,7 +129,7 @@ public class DTypeTests extends BaseDL4JTest {
throw new RuntimeException(e);
}
- if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface()) {
+ if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) { //Skip TFOpLayer here - dtype depends on imported model dtype
continue;
}
diff --git a/deeplearning4j/deeplearning4j-modelimport/pom.xml b/deeplearning4j/deeplearning4j-modelimport/pom.xml
index 566bf6012..6d71c394e 100644
--- a/deeplearning4j/deeplearning4j-modelimport/pom.xml
+++ b/deeplearning4j/deeplearning4j-modelimport/pom.xml
@@ -105,6 +105,14 @@
             <version>${project.version}</version>
             <scope>test</scope>
         </dependency>
+
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>nd4j-tensorflow</artifactId>
+            <version>${nd4j.version}</version>
+            <scope>test</scope>
+        </dependency>
+
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/Keras2LayerConfiguration.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/Keras2LayerConfiguration.java
index 430b7407a..9b91d10cc 100644
--- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/Keras2LayerConfiguration.java
+++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/Keras2LayerConfiguration.java
@@ -103,4 +103,6 @@ public class Keras2LayerConfiguration extends KerasLayerConfiguration {
/* Keras weight initializers. */
private final String LAYER_FIELD_INIT = "kernel_initializer";
+
+ private final String TENSORFLOW_OP_LAYER = "TensorFlowOpLayer";
}
\ No newline at end of file
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasTFOpLayer.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasTFOpLayer.java
new file mode 100644
index 000000000..2dd95338a
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasTFOpLayer.java
@@ -0,0 +1,74 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.nn.modelimport.keras.layers;
+
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
+import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
+import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
+
+import java.util.Map;
+
+
+/**
+ * Keras-import wrapper mapping a tf.keras "TensorFlowOpLayer" (a raw TensorFlow
+ * op embedded in a Keras 2 model) onto a DL4J {@link TFOpLayer} configuration.
+ */
+public class KerasTFOpLayer extends KerasLayer {
+
+    /**
+     * @param kerasVersion major Keras version; only Keras 2 models can contain TF op layers
+     * @throws UnsupportedKerasConfigurationException if kerasVersion != 2
+     */
+    public KerasTFOpLayer(Integer kerasVersion) throws UnsupportedKerasConfigurationException {
+        super(kerasVersion);
+        if (kerasVersion != 2) {
+            throw new UnsupportedKerasConfigurationException("KerasTFOpLayer expects Keras version 2");
+        }
+    }
+
+    /**
+     * Constructor from parsed Keras layer configuration dictionary.
+     *
+     * @param layerConfig dictionary containing Keras layer configuration
+     * @throws InvalidKerasConfigurationException Invalid Keras config
+     * @throws UnsupportedKerasConfigurationException Unsupported Keras config
+     */
+    public KerasTFOpLayer(Map<String, Object> layerConfig)
+            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
+        this(layerConfig, true);
+    }
+
+    /**
+     * Constructor from parsed Keras layer configuration dictionary.
+     *
+     * @param layerConfig dictionary containing Keras layer configuration
+     * @param enforceTrainingConfig whether to enforce training-related configuration options
+     * @throws InvalidKerasConfigurationException Invalid Keras config
+     * @throws UnsupportedKerasConfigurationException Unsupported Keras config
+     */
+    @SuppressWarnings("unchecked") // "config" sub-map comes from parsed JSON, so raw casts are unavoidable
+    public KerasTFOpLayer(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
+            throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
+        super(layerConfig, enforceTrainingConfig);
+        // Look up the inner "config" map once instead of twice
+        Map config = (Map) layerConfig.get("config");
+        this.layer = new TFOpLayer((Map) config.get("node_def"), (Map) config.get("constants"));
+    }
+
+    /**
+     * Get layer output type.
+     *
+     * @param inputType Array of InputTypes; only the first element is used
+     * @return output type as InputType
+     */
+    public InputType getOutputType(InputType... inputType) {
+        return this.layer.getOutputType(0, inputType[0]);
+    }
+}
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java
new file mode 100644
index 000000000..ecf64e8c0
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java
@@ -0,0 +1,106 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.nn.modelimport.keras.layers;
+
+import org.deeplearning4j.nn.api.ParamInitializer;
+import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.nn.conf.InputPreProcessor;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.Layer;
+import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
+import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayerImpl;
+import org.deeplearning4j.nn.params.EmptyParamInitializer;
+import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.regularization.Regularization;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
+
+/**
+ * Configuration-side layer wrapping a single raw TensorFlow op imported from a
+ * tf.keras model. Stores the op's NodeDef (as a parsed JSON map) plus any constant
+ * inputs; execution is delegated to {@link TFOpLayerImpl}. Has no trainable parameters.
+ */
+public class TFOpLayer extends Layer {
+
+    private Map<String, Object> nodeDef;   // JSON-map form of the TF NodeDef proto
+    private Map<String, Object> constants; // constant inputs, keyed by input index (as a String)
+
+    public TFOpLayer(Map<String, Object> nodeDef, Map<String, Object> constants) {
+        super();
+        this.nodeDef = nodeDef;
+        this.constants = constants;
+    }
+
+    /** No trainable parameters: use the shared empty initializer. */
+    @Override
+    public ParamInitializer initializer() {
+        return EmptyParamInitializer.getInstance();
+    }
+
+    @Override
+    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
+        return null;
+    }
+
+    @Override
+    public boolean isPretrainParam(String param) {
+        return false;
+    }
+
+    /**
+     * Infers the output type by executing the wrapped TF op once on a zero-filled
+     * dummy input of the given shape (see {@link TFOpLayerImpl#getOutputShape(long[])}).
+     * NOTE(review): this runs the TF graph at configuration time — confirm acceptable.
+     */
+    @Override
+    public InputType getOutputType(int idx, InputType inputType) {
+        long[] shape = inputType.getShape(true);
+        TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null);
+        long[] outputShape = tempLayer.getOutputShape(shape);
+        return InputType.inferInputType(Nd4j.create(outputShape));
+    }
+
+    /** Nothing to set: the op's shapes are fixed by the imported graph. */
+    @Override
+    public void setNIn(InputType inputType, boolean override) {}
+
+    @Override
+    public GradientNormalization getGradientNormalization() { return null; }
+
+    @Override
+    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
+                    Collection<TrainingListener> trainingListeners, int layerIndex,
+                    INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
+        TFOpLayerImpl tfOpLayerImpl = new TFOpLayerImpl(nodeDef, constants, conf, networkDataType);
+        tfOpLayerImpl.setListeners(trainingListeners);
+        tfOpLayerImpl.setIndex(layerIndex);
+        return tfOpLayerImpl;
+    }
+
+    @Override
+    public double getGradientNormalizationThreshold() { return 0.; }
+
+    /** No parameters, hence no per-parameter regularization. */
+    @Override
+    public List<Regularization> getRegularizationByParam(String paramName) { return null; }
+
+    @Override
+    public LayerMemoryReport getMemoryReport(InputType inputType) {
+        return new LayerMemoryReport(); // TODO: report actual TF-side memory use
+    }
+}
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java
new file mode 100644
index 000000000..d7b0b3b56
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java
@@ -0,0 +1,169 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.nn.modelimport.keras.layers;
+
+import lombok.Data;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.ArrayUtils;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.nn.layers.AbstractLayer;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.TFGraphRunnerService;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.primitives.Pair;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.NodeDef;
+import com.google.gson.Gson;
+import org.nd4j.shade.protobuf.Message;
+import org.nd4j.shade.protobuf.TextFormat;
+
+import java.util.*;
+import java.util.List;
+
+
+@Slf4j
+@Data
+public class TFOpLayerImpl extends AbstractLayer {
+
+
+ private Map nodeDef;
+ private Map constants;
+ private List inputNames;
+ TFGraphRunnerService graphRunnerService;
+
+ public TFOpLayerImpl(Map nodeDef, Map constants, NeuralNetConfiguration conf, DataType dtype){
+ super(conf, dtype);
+ this.nodeDef = nodeDef;
+ this.constants = constants;
+ setGraphRunner();
+ }
+
+ @Override
+ public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr){
+ throw new RuntimeException("Backprop through TFOpLayerImpl is not supported yet." +
+ " TFOpLayerImpl is created when importing TensorFlow 2.0 Keras models " +
+ "(tf.keras) into DL4J, that contains TensorFlow operations not just Keras layers.");
+ }
+
+ /**
+ * Converts a Map representation of Nodedef to a singleton TF Graph and instantiates a GraphRunner.
+ */
+ private void setGraphRunner() {
+ try{
+ String json = new Gson().toJson(nodeDef);
+ NodeDef.Builder builder = NodeDef.newBuilder();
+ org.nd4j.shade.protobuf.util.JsonFormat.parser().merge(json, builder);
+ NodeDef nodeDef = builder.build();
+ List allInputNames = new ArrayList<>(); // including constants
+ Map inputDataTypes = new HashMap<>();
+ Map constArrays = new HashMap();
+ this.inputNames = new ArrayList<>();
+ List outputNames = Arrays.asList(nodeDef.getName());
+ Map attrMap = nodeDef.getAttrMap();
+ for (int i = 0; i < nodeDef.getInputCount(); i++){
+ String inputName = nodeDef.getInput(i);
+ String[] split = inputName.split("/");
+ String attrKey;
+ if (split.length == 1){
+ attrKey = "T";
+ }
+ else{
+ attrKey = "T" + split[split.length - 1];
+ }
+ allInputNames.add(nodeDef.getInput(i));
+ inputDataTypes.put(nodeDef.getInput(i), attrMap.get(attrKey).getType().toString());
+ if (constants.containsKey(String.valueOf(i))){
+ constArrays.put(nodeDef.getInput(i), Nd4j.create((List)constants.get(String.valueOf(i))));
+ }
+ else{
+ this.inputNames.add(nodeDef.getInput(i));
+ }
+ }
+ String graph = "node{\n" + nodeDef.toString() + "\n}\nversions {\n producer: 22\n}";
+ for (int i = 0; i < allInputNames.size(); i++){
+ String inpName = allInputNames.get(i);
+ String dtype = inputDataTypes.get(inpName);
+ graph = "node{\nname: \"" + inpName + "\"\nop: \"Placeholder\"\nattr{\nkey: \"dtype\"\n value {\n type: " + dtype + "}\n}\n}\n" + graph;
+ }
+ log.info(graph);
+ GraphDef.Builder graphDefBuilder = GraphDef.newBuilder();
+ TextFormat.getParser().merge(graph, graphDefBuilder);
+ GraphDef graphDef = graphDefBuilder.build();
+ org.nd4j.shade.protobuf.ByteString serialized = graphDef.toByteString();
+ byte[] graphBytes = serialized.toByteArray();
+
+ ServiceLoader sl = ServiceLoader.load(TFGraphRunnerService.class);
+ Iterator iter = sl.iterator();
+ if (!iter.hasNext()){
+ throw new RuntimeException("The model contains a Tensorflow Op, which requires the nd4j-tensorflow dependency to execute.");
+ }
+
+ this.graphRunnerService = iter.next().init(allInputNames, outputNames, graphBytes, constArrays, inputDataTypes);
+ }
+ catch (Exception e){
+ throw new RuntimeException("Error parsing protobuf", e);
+ }
+
+ }
+
+ private INDArray runGraph(INDArray input){
+ if (input.rank() == 3){
+ // TODO make this a preprocessor
+ input = input.permute(0, 2, 1);
+ }
+ Map inputMap = new HashMap<>();
+ inputMap.put(inputNames.get(0), input);
+ INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
+ if (out.rank() == 3){
+ out = out.permute(0, 2, 1); // TODO post-processing?
+ }
+
+ return out;
+ }
+
+ public long[] getOutputShape(long[] inputShape){
+ long[] shape = ArrayUtils.clone(inputShape);
+ for(int i = 0; i < shape.length; i++){
+ if (shape[i] < 0){
+ shape[i] = 1;
+ }
+ }
+ INDArray dummyArr = Nd4j.zeros(shape);
+ return runGraph(dummyArr).shape();
+ }
+
+ @Override
+ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
+ return runGraph(input);
+ }
+
+
+ @Override
+ public boolean isPretrainLayer(){
+ return false;
+ }
+
+ @Override
+ public void clearNoiseWeightParams(){
+
+ }
+
+}
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java
index 1428b6322..3f69cb7d4 100644
--- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java
+++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java
@@ -21,10 +21,12 @@ import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
+import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
+import org.deeplearning4j.nn.modelimport.keras.layers.KerasTFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
@@ -317,6 +319,11 @@ public class KerasLayerUtils {
layer = new KerasELU(layerConfig, enforceTrainingConfig);
} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
+ } else if (conf instanceof Keras2LayerConfiguration){
+ Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
+ if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){
+ layer = new KerasTFOpLayer(layerConfig, enforceTrainingConfig);
+ }
}
if (layer == null){
             Class<? extends KerasLayer> customConfig = customLayers.get(layerClassName);
@@ -402,6 +409,16 @@ public class KerasLayerUtils {
public static String getLayerNameFromConfig(Map layerConfig,
KerasLayerConfiguration conf)
throws InvalidKerasConfigurationException {
+ if(conf instanceof Keras2LayerConfiguration){
+ Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
+ if (getClassNameFromConfig(layerConfig, conf).equals(((Keras2LayerConfiguration) conf).getTENSORFLOW_OP_LAYER())){
+ if (!layerConfig.containsKey(conf.getLAYER_FIELD_NAME()))
+ throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME()
+ + " missing from layer config");
+ return (String) layerConfig.get(conf.getLAYER_FIELD_NAME());
+ }
+ }
+
Map innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
if (!innerConfig.containsKey(conf.getLAYER_FIELD_NAME()))
throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME()
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java
new file mode 100644
index 000000000..cb74b1ed1
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java
@@ -0,0 +1,50 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.nn.modelimport.keras;
+
+import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.junit.Assert;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.resources.Resources;
+
+import java.io.File;
+import java.util.Arrays;
+
+public class TFKerasTests extends BaseDL4JTest{
+
+    // Imports a tf.keras model containing a raw TF op (Reshape) and checks the output shape.
+    @Test
+    public void testModelWithTFOp1() throws Exception{
+        File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5");
+        ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
+        INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
+        Assert.assertArrayEquals(new long[]{12, 3}, out.shape());
+    }
+
+    // Imports a tf.keras model containing a raw TF op (Permute/transpose) and checks the output shape.
+    @Test
+    public void testModelWithTFOp2() throws Exception{
+        File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5");
+        ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
+        INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
+        // dl4j's feedforward doesn't support 3D output, so batch and time axes gets squashed
+        long[] expectedShape = new long[]{12 * 2, 5};
+        Assert.assertArrayEquals(expectedShape, out.shape());
+    }
+
+}
diff --git a/deeplearning4j/deeplearning4j-nn/pom.xml b/deeplearning4j/deeplearning4j-nn/pom.xml
index e92372fc8..77acb2dc7 100644
--- a/deeplearning4j/deeplearning4j-nn/pom.xml
+++ b/deeplearning4j/deeplearning4j-nn/pom.xml
@@ -77,7 +77,11 @@
             <artifactId>nd4j-common</artifactId>
             <version>${nd4j.version}</version>
         </dependency>
-
+        <dependency>
+            <groupId>com.google.code.gson</groupId>
+            <artifactId>gson</artifactId>
+            <version>${gson.version}</version>
+        </dependency>
org.nd4j
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java
index 750bca77d..ad8590b0b 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java
@@ -62,6 +62,7 @@ public abstract class AbstractLayer inputNames,
+            List<String> outputNames,
+            byte[] graphBytes,
+            Map<String, INDArray> constants,
+            Map<String, String> inputDataTypes
+    );
+
+    Map<String, INDArray> run(Map<String, INDArray> inputs);
+}
diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java
index 9cb0a609b..49861e3fe 100644
--- a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java
+++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java
@@ -16,18 +16,16 @@
package org.nd4j.tensorflow.conversion.graphrunner;
-import lombok.Builder;
-import lombok.Singular;
+import lombok.*;
import org.apache.commons.io.FileUtils;
import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.protobuf.ByteString;
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
import org.nd4j.shade.protobuf.util.JsonFormat;
-import lombok.Getter;
-import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.tensorflow.conversion.TensorDataType;
import org.apache.commons.io.IOUtils;
@@ -56,6 +54,7 @@ import static org.bytedeco.tensorflow.global.tensorflow.*;
* @author Adam Gibson
*/
@Slf4j
+@NoArgsConstructor
public class GraphRunner implements Closeable {
private static boolean isTfWarmedUp = false;
@@ -103,6 +102,9 @@ public class GraphRunner implements Closeable {
* @param inputDataTypes the expected input data types
* @param outputDataTypes the expected output data types
*/
+
+
+
@Builder
public GraphRunner(List inputNames,
List outputNames,
@@ -440,6 +442,7 @@ public class GraphRunner implements Closeable {
* @return a map of the output names to the
* ndarrays matching each output specified in the graph
*/
+
public Map run(Map inputs) {
if (!isTfWarmedUp && !isTfWarmingUp){
isTfWarmingUp = true;
@@ -683,4 +686,7 @@ public class GraphRunner implements Closeable {
return builder1.build();
}
+
+
+
}
diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunnerServiceProvider.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunnerServiceProvider.java
new file mode 100644
index 000000000..7459a40ea
--- /dev/null
+++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunnerServiceProvider.java
@@ -0,0 +1,52 @@
+package org.nd4j.tensorflow.conversion.graphrunner;
+
+import org.nd4j.TFGraphRunnerService;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.tensorflow.conversion.TensorDataType;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * {@link TFGraphRunnerService} implementation backed by nd4j-tensorflow's
+ * {@link GraphRunner}; discovered at runtime via {@link java.util.ServiceLoader}
+ * (registered in META-INF/services).
+ */
+public class GraphRunnerServiceProvider implements TFGraphRunnerService {
+
+    private GraphRunner graphRunner;
+    Map<String, INDArray> inputs; // constants (pre-cast to their declared dtypes) merged with per-call inputs
+
+    @Override
+    public TFGraphRunnerService init(
+            List<String> inputNames,
+            List<String> outputNames,
+            byte[] graphBytes,
+            Map<String, INDArray> constants,
+            Map<String, String> inputDataTypes) {
+        if (inputNames.size() != inputDataTypes.size()) {
+            throw new IllegalArgumentException("inputNames.size() != inputDataTypes.size()");
+        }
+        Map<String, TensorDataType> convertedDataTypes = new HashMap<>();
+        for (int i = 0; i < inputNames.size(); i++) {
+            convertedDataTypes.put(inputNames.get(i), TensorDataType.fromProtoValue(inputDataTypes.get(inputNames.get(i))));
+        }
+        // Cast each constant to the dtype the graph declares for that input.
+        Map<String, INDArray> castConstants = new HashMap<>();
+        for (Map.Entry<String, INDArray> e : constants.entrySet()) {
+            DataType requiredDtype = TensorDataType.toNd4jType(TensorDataType.fromProtoValue(inputDataTypes.get(e.getKey())));
+            castConstants.put(e.getKey(), e.getValue().castTo(requiredDtype));
+        }
+        this.inputs = castConstants;
+        graphRunner = GraphRunner.builder().inputNames(inputNames)
+                .outputNames(outputNames).graphBytes(graphBytes)
+                .inputDataTypes(convertedDataTypes).build();
+        return this;
+    }
+
+    @Override
+    public Map<String, INDArray> run(Map<String, INDArray> inputs) {
+        if (graphRunner == null) {
+            throw new RuntimeException("GraphRunner not initialized.");
+        }
+        // Merge per-call inputs over the stored constants. NOTE(review): this mutates
+        // shared state, so concurrent run() calls are unsafe — confirm single-threaded use.
+        this.inputs.putAll(inputs);
+        return graphRunner.run(this.inputs);
+    }
+}
diff --git a/nd4j/nd4j-tensorflow/src/main/resources/META-INF/services/org.nd4j.TFGraphRunnerService b/nd4j/nd4j-tensorflow/src/main/resources/META-INF/services/org.nd4j.TFGraphRunnerService
new file mode 100644
index 000000000..1b038ee6c
--- /dev/null
+++ b/nd4j/nd4j-tensorflow/src/main/resources/META-INF/services/org.nd4j.TFGraphRunnerService
@@ -0,0 +1,17 @@
+ ################################################################################
+ # Copyright (c) 2020 Konduit K.K..
+ #
+ # This program and the accompanying materials are made available under the
+ # terms of the Apache License, Version 2.0 which is available at
+ # https://www.apache.org/licenses/LICENSE-2.0.
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ # License for the specific language governing permissions and limitations
+ # under the License.
+ #
+ # SPDX-License-Identifier: Apache-2.0
+ ################################################################################
+
+org.nd4j.tensorflow.conversion.graphrunner.GraphRunnerServiceProvider