diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index d9da12b62..6831af10b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.dtypes; import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; +import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer; import org.nd4j.shade.guava.collect.ImmutableSet; import org.nd4j.shade.guava.reflect.ClassPath; import lombok.extern.slf4j.Slf4j; @@ -128,7 +129,7 @@ public class DTypeTests extends BaseDL4JTest { throw new RuntimeException(e); } - if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface()) { + if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) { //Skip TFOpLayer here - dtype depends on imported model dtype continue; } diff --git a/deeplearning4j/deeplearning4j-modelimport/pom.xml b/deeplearning4j/deeplearning4j-modelimport/pom.xml index 566bf6012..6d71c394e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/pom.xml +++ b/deeplearning4j/deeplearning4j-modelimport/pom.xml @@ -105,6 +105,14 @@ ${project.version} test + + + org.nd4j + nd4j-tensorflow + ${nd4j.version} + test + + diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/Keras2LayerConfiguration.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/Keras2LayerConfiguration.java index 430b7407a..9b91d10cc 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/Keras2LayerConfiguration.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/Keras2LayerConfiguration.java @@ -103,4 +103,6 @@ public class Keras2LayerConfiguration extends KerasLayerConfiguration { /* Keras weight initializers. */ private final String LAYER_FIELD_INIT = "kernel_initializer"; + + private final String TENSORFLOW_OP_LAYER = "TensorFlowOpLayer"; } \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasTFOpLayer.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasTFOpLayer.java new file mode 100644 index 000000000..2dd95338a --- /dev/null +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasTFOpLayer.java @@ -0,0 +1,74 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.modelimport.keras.layers; + +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.modelimport.keras.KerasLayer; +import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; +import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; + +import java.util.Map; + + +public class KerasTFOpLayer extends KerasLayer { + + public KerasTFOpLayer(Integer kerasVersion) throws UnsupportedKerasConfigurationException { + super(kerasVersion); + if (kerasVersion != 2){ + throw new UnsupportedKerasConfigurationException("KerasTFOpLayer expects Keras version 2"); + } + } + + /** + * Constructor from parsed Keras layer configuration dictionary. + * + * @param layerConfig dictionary containing Keras layer configuration + * @throws InvalidKerasConfigurationException Invalid Keras config + * @throws UnsupportedKerasConfigurationException Unsupported Keras config + */ + public KerasTFOpLayer(Map layerConfig) + throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { + this(layerConfig, true); + } + + /** + * Constructor from parsed Keras layer configuration dictionary. + * + * @param layerConfig dictionary containing Keras layer configuration + * @param enforceTrainingConfig whether to enforce training-related configuration options + * @throws InvalidKerasConfigurationException Invalid Keras config + * @throws UnsupportedKerasConfigurationException Unsupported Keras config + */ + public KerasTFOpLayer(Map layerConfig, boolean enforceTrainingConfig) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException{ + super(layerConfig, enforceTrainingConfig); + this.layer = new TFOpLayer((Map)((Map)layerConfig.get("config")).get("node_def"), (Map)((Map)layerConfig.get("config")).get("constants")); + } + + /** + * Get layer output type. + * + * @param inputType Array of InputTypes + * @return output type as InputType + * @throws InvalidKerasConfigurationException Invalid Keras configuration + */ + public InputType getOutputType(InputType... inputType){ + return this.layer.getOutputType(0, inputType[0]); + } + + + +} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java new file mode 100644 index 000000000..ecf64e8c0 --- /dev/null +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java @@ -0,0 +1,106 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.modelimport.keras.layers; + +import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; +import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayerImpl; +import org.deeplearning4j.nn.params.EmptyParamInitializer; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.regularization.Regularization; + +import java.util.Collection; +import java.util.List; +import java.util.Map; + + +public class TFOpLayer extends Layer { + + private Map nodeDef; + private Map constants; + public TFOpLayer(Map nodeDef, Map constants){ + super(); + this.nodeDef = nodeDef; + this.constants = constants; + } + + @Override + public ParamInitializer initializer() { + return EmptyParamInitializer.getInstance(); + } + @Override + public InputPreProcessor getPreProcessorForInputType(InputType inputType) { + return null; + } + + @Override + public boolean isPretrainParam(String param){ + return false; + } + + @Override + public InputType getOutputType(int idx, InputType inputType){ + long[] shape = inputType.getShape(true); + TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null); + long[] outputShape = tempLayer.getOutputShape(shape); + return InputType.inferInputType(Nd4j.create(outputShape)); + + } + + @Override + public void setNIn(InputType inputType, boolean override){} + + + @Override + public GradientNormalization getGradientNormalization(){return null;} + + + @Override + public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { + + TFOpLayerImpl tfOpLayerImpl = new TFOpLayerImpl(nodeDef, constants, conf, networkDataType); + tfOpLayerImpl.setListeners(trainingListeners); + tfOpLayerImpl.setIndex(layerIndex); + return tfOpLayerImpl; + } + + @Override + public double getGradientNormalizationThreshold(){return 0.;} + + @Override + public List getRegularizationByParam(String paramName){return null;} + + @Override + public LayerMemoryReport getMemoryReport(InputType inputType) { + return new LayerMemoryReport(); //TODO + } + + + + + +} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java new file mode 100644 index 000000000..d7b0b3b56 --- /dev/null +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java @@ -0,0 +1,169 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.modelimport.keras.layers; + +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.ArrayUtils; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.layers.AbstractLayer; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.TFGraphRunnerService; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; +import com.google.gson.Gson; +import org.nd4j.shade.protobuf.Message; +import org.nd4j.shade.protobuf.TextFormat; + +import java.util.*; +import java.util.List; + + +@Slf4j +@Data +public class TFOpLayerImpl extends AbstractLayer { + + + private Map nodeDef; + private Map constants; + private List inputNames; + TFGraphRunnerService graphRunnerService; + + public TFOpLayerImpl(Map nodeDef, Map constants, NeuralNetConfiguration conf, DataType dtype){ + super(conf, dtype); + this.nodeDef = nodeDef; + this.constants = constants; + setGraphRunner(); + } + + @Override + public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr){ + throw new RuntimeException("Backprop through TFOpLayerImpl is not supported yet." + + " TFOpLayerImpl is created when importing TensorFlow 2.0 Keras models " + + "(tf.keras) into DL4J, that contains TensorFlow operations not just Keras layers."); + } + + /** + * Converts a Map representation of Nodedef to a singleton TF Graph and instantiates a GraphRunner. + */ + private void setGraphRunner() { + try{ + String json = new Gson().toJson(nodeDef); + NodeDef.Builder builder = NodeDef.newBuilder(); + org.nd4j.shade.protobuf.util.JsonFormat.parser().merge(json, builder); + NodeDef nodeDef = builder.build(); + List allInputNames = new ArrayList<>(); // including constants + Map inputDataTypes = new HashMap<>(); + Map constArrays = new HashMap(); + this.inputNames = new ArrayList<>(); + List outputNames = Arrays.asList(nodeDef.getName()); + Map attrMap = nodeDef.getAttrMap(); + for (int i = 0; i < nodeDef.getInputCount(); i++){ + String inputName = nodeDef.getInput(i); + String[] split = inputName.split("/"); + String attrKey; + if (split.length == 1){ + attrKey = "T"; + } + else{ + attrKey = "T" + split[split.length - 1]; + } + allInputNames.add(nodeDef.getInput(i)); + inputDataTypes.put(nodeDef.getInput(i), attrMap.get(attrKey).getType().toString()); + if (constants.containsKey(String.valueOf(i))){ + constArrays.put(nodeDef.getInput(i), Nd4j.create((List)constants.get(String.valueOf(i)))); + } + else{ + this.inputNames.add(nodeDef.getInput(i)); + } + } + String graph = "node{\n" + nodeDef.toString() + "\n}\nversions {\n producer: 22\n}"; + for (int i = 0; i < allInputNames.size(); i++){ + String inpName = allInputNames.get(i); + String dtype = inputDataTypes.get(inpName); + graph = "node{\nname: \"" + inpName + "\"\nop: \"Placeholder\"\nattr{\nkey: \"dtype\"\n value {\n type: " + dtype + "}\n}\n}\n" + graph; + } + log.info(graph); + GraphDef.Builder graphDefBuilder = GraphDef.newBuilder(); + TextFormat.getParser().merge(graph, graphDefBuilder); + GraphDef graphDef = graphDefBuilder.build(); + org.nd4j.shade.protobuf.ByteString serialized = graphDef.toByteString(); + byte[] graphBytes = serialized.toByteArray(); + + ServiceLoader sl = ServiceLoader.load(TFGraphRunnerService.class); + Iterator iter = sl.iterator(); + if (!iter.hasNext()){ + throw new RuntimeException("The model contains a Tensorflow Op, which requires the nd4j-tensorflow dependency to execute."); + } + + this.graphRunnerService = iter.next().init(allInputNames, outputNames, graphBytes, constArrays, inputDataTypes); + } + catch (Exception e){ + throw new RuntimeException("Error parsing protobuf", e); + } + + } + + private INDArray runGraph(INDArray input){ + if (input.rank() == 3){ + // TODO make this a preprocessor + input = input.permute(0, 2, 1); + } + Map inputMap = new HashMap<>(); + inputMap.put(inputNames.get(0), input); + INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0]; + if (out.rank() == 3){ + out = out.permute(0, 2, 1); // TODO post-processing? + } + + return out; + } + + public long[] getOutputShape(long[] inputShape){ + long[] shape = ArrayUtils.clone(inputShape); + for(int i = 0; i < shape.length; i++){ + if (shape[i] < 0){ + shape[i] = 1; + } + } + INDArray dummyArr = Nd4j.zeros(shape); + return runGraph(dummyArr).shape(); + } + + @Override + public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){ + return runGraph(input); + } + + + @Override + public boolean isPretrainLayer(){ + return false; + } + + @Override + public void clearNoiseWeightParams(){ + + } + +} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java index 1428b6322..3f69cb7d4 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java @@ -21,10 +21,12 @@ import org.deeplearning4j.nn.conf.graph.ElementWiseVertex; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; +import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput; +import org.deeplearning4j.nn.modelimport.keras.layers.KerasTFOpLayer; import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*; import org.deeplearning4j.nn.modelimport.keras.layers.core.*; @@ -317,6 +319,11 @@ public class KerasLayerUtils { layer = new KerasELU(layerConfig, enforceTrainingConfig); } else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){ layer = new KerasSoftmax(layerConfig, enforceTrainingConfig); + } else if (conf instanceof Keras2LayerConfiguration){ + Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf; + if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){ + layer = new KerasTFOpLayer(layerConfig, enforceTrainingConfig); + } } if (layer == null){ Class customConfig = customLayers.get(layerClassName); @@ -402,6 +409,16 @@ public class KerasLayerUtils { public static String getLayerNameFromConfig(Map layerConfig, KerasLayerConfiguration conf) throws InvalidKerasConfigurationException { + if(conf instanceof Keras2LayerConfiguration){ + Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf; + if (getClassNameFromConfig(layerConfig, conf).equals(((Keras2LayerConfiguration) conf).getTENSORFLOW_OP_LAYER())){ + if (!layerConfig.containsKey(conf.getLAYER_FIELD_NAME())) + throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME() + + " missing from layer config"); + return (String) layerConfig.get(conf.getLAYER_FIELD_NAME()); + } + } + Map innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); if (!innerConfig.containsKey(conf.getLAYER_FIELD_NAME())) throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME() diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java new file mode 100644 index 000000000..cb74b1ed1 --- /dev/null +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java @@ -0,0 +1,50 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.modelimport.keras; + +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.resources.Resources; + +import java.io.File; +import java.util.Arrays; + +public class TFKerasTests extends BaseDL4JTest{ + + @Test + public void testModelWithTFOp1() throws Exception{ + File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5"); + ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath()); + INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3)); + Assert.assertArrayEquals(new long[]{12, 3}, out.shape()); + } + + @Test + public void testModelWithTFOp2() throws Exception{ + File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5"); + ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath()); + INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3)); + // dl4j's feedforward doesn't support 3D output, so batch and time axes gets squashed + long[] expectedShape = new long[]{12 * 2, 5}; + Assert.assertArrayEquals(expectedShape, out.shape()); + } + +} diff --git a/deeplearning4j/deeplearning4j-nn/pom.xml b/deeplearning4j/deeplearning4j-nn/pom.xml index e92372fc8..77acb2dc7 100644 --- a/deeplearning4j/deeplearning4j-nn/pom.xml +++ b/deeplearning4j/deeplearning4j-nn/pom.xml @@ -77,7 +77,11 @@ nd4j-common ${nd4j.version} - + + com.google.code.gson + gson + ${gson.version} + org.nd4j diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java index 750bca77d..ad8590b0b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java @@ -62,6 +62,7 @@ public abstract class AbstractLayer inputNames, + List outputNames, + byte[] graphBytes, + Map constants, + Map inputDataTypes + ); + + Map run(Map inputs); +} diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java index 9cb0a609b..49861e3fe 100644 --- a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java +++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java @@ -16,18 +16,16 @@ package org.nd4j.tensorflow.conversion.graphrunner; -import lombok.Builder; -import lombok.Singular; +import lombok.*; import org.apache.commons.io.FileUtils; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.primitives.Pair; import org.nd4j.shade.protobuf.ByteString; import org.nd4j.shade.protobuf.InvalidProtocolBufferException; import org.nd4j.shade.protobuf.util.JsonFormat; -import lombok.Getter; -import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.nd4j.tensorflow.conversion.TensorDataType; import org.apache.commons.io.IOUtils; @@ -56,6 +54,7 @@ import static org.bytedeco.tensorflow.global.tensorflow.*; * @author Adam Gibson */ @Slf4j +@NoArgsConstructor public class GraphRunner implements Closeable { private static boolean isTfWarmedUp = false; @@ -103,6 +102,9 @@ public class GraphRunner implements Closeable { * @param inputDataTypes the expected input data types * @param outputDataTypes the expected output data types */ + + + @Builder public GraphRunner(List inputNames, List outputNames, @@ -440,6 +442,7 @@ public class GraphRunner implements Closeable { * @return a map of the output names to the * ndarrays matching each output specified in the graph */ + public Map run(Map inputs) { if (!isTfWarmedUp && !isTfWarmingUp){ isTfWarmingUp = true; @@ -683,4 +686,7 @@ public class GraphRunner implements Closeable { return builder1.build(); } + + + } diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunnerServiceProvider.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunnerServiceProvider.java new file mode 100644 index 000000000..7459a40ea --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunnerServiceProvider.java @@ -0,0 +1,52 @@ +package org.nd4j.tensorflow.conversion.graphrunner; + +import org.nd4j.TFGraphRunnerService; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.tensorflow.conversion.TensorDataType; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class GraphRunnerServiceProvider implements TFGraphRunnerService { + + private GraphRunner graphRunner; + Map inputs; + + @Override + public TFGraphRunnerService init( + List inputNames, + List outputNames, + byte[] graphBytes, + Map constants, + Map inputDataTypes){ + if (inputNames.size() != inputDataTypes.size()){ + throw new IllegalArgumentException("inputNames.size() != inputDataTypes.size()"); + } + Map convertedDataTypes = new HashMap<>(); + for (int i = 0; i < inputNames.size(); i++){ + convertedDataTypes.put(inputNames.get(i), TensorDataType.fromProtoValue(inputDataTypes.get(inputNames.get(i)))); + } + Map castConstants = new HashMap<>(); + for (Map.Entry e: constants.entrySet()) { + DataType requiredDtype = TensorDataType.toNd4jType(TensorDataType.fromProtoValue(inputDataTypes.get(e.getKey()))); + castConstants.put(e.getKey(), e.getValue().castTo(requiredDtype)); + } + this.inputs = castConstants; + graphRunner = GraphRunner.builder().inputNames(inputNames) + .outputNames(outputNames).graphBytes(graphBytes) + .inputDataTypes(convertedDataTypes).build(); + return this; + + } + + @Override + public Map run(Map inputs){ + if (graphRunner == null){ + throw new RuntimeException("GraphRunner not initialized."); + } + this.inputs.putAll(inputs); + return graphRunner.run(this.inputs); + } +} diff --git a/nd4j/nd4j-tensorflow/src/main/resources/META-INF/services/org.nd4j.TFGraphRunnerService b/nd4j/nd4j-tensorflow/src/main/resources/META-INF/services/org.nd4j.TFGraphRunnerService new file mode 100644 index 000000000..1b038ee6c --- /dev/null +++ b/nd4j/nd4j-tensorflow/src/main/resources/META-INF/services/org.nd4j.TFGraphRunnerService @@ -0,0 +1,17 @@ + ################################################################################ + # Copyright (c) 2020 Konduit K.K.. + # + # This program and the accompanying materials are made available under the + # terms of the Apache License, Version 2.0 which is available at + # https://www.apache.org/licenses/LICENSE-2.0. + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + # License for the specific language governing permissions and limitations + # under the License. + # + # SPDX-License-Identifier: Apache-2.0 + ################################################################################ + +org.nd4j.tensorflow.conversion.graphrunner.GraphRunnerServiceProvider