diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java
index 70945b06c..aa0cab9a0 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java
@@ -43,9 +43,12 @@ import org.nd4j.linalg.primitives.Pair;
import java.io.*;
import java.util.ArrayList;
import java.util.Enumeration;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
+import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
/**
@@ -215,7 +218,24 @@ public class ModelSerializer {
*/
public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater)
throws IOException {
- ZipFile zipFile = new ZipFile(file);
+ return restoreMultiLayerNetwork(new FileInputStream(file), loadUpdater);
+ }
+
+
+ /**
+ * Load a MultiLayerNetwork from InputStream from an input stream
+ * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
+ *
+ * @param is the inputstream to load from
+ * @return the loaded multi layer network
+ * @throws IOException
+ * @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean)
+ */
+ public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
+ throws IOException {
+ checkInputStream(is);
+
+ Map zipFile = loadZipData(is);
boolean gotConfig = false;
boolean gotCoefficients = false;
@@ -229,11 +249,11 @@ public class ModelSerializer {
DataSetPreProcessor preProcessor = null;
- ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON);
+ byte[] config = zipFile.get(CONFIGURATION_JSON);
if (config != null) {
//restoring configuration
- InputStream stream = zipFile.getInputStream(config);
+ InputStream stream = new ByteArrayInputStream(config);
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
String line = "";
StringBuilder js = new StringBuilder();
@@ -248,25 +268,25 @@ public class ModelSerializer {
}
- ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN);
+ byte[] coefficients = zipFile.get(COEFFICIENTS_BIN);
if (coefficients != null ) {
- if(coefficients.getSize() > 0) {
- InputStream stream = zipFile.getInputStream(coefficients);
+ if(coefficients.length > 0) {
+ InputStream stream = new ByteArrayInputStream(coefficients);
DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
params = Nd4j.read(dis);
dis.close();
gotCoefficients = true;
} else {
- ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER);
+ byte[] noParamsMarker = zipFile.get(NO_PARAMS_MARKER);
gotCoefficients = (noParamsMarker != null);
}
}
if (loadUpdater) {
- ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
+ byte[] updaterStateEntry = zipFile.get(UPDATER_BIN);
if (updaterStateEntry != null) {
- InputStream stream = zipFile.getInputStream(updaterStateEntry);
+ InputStream stream = new ByteArrayInputStream(updaterStateEntry);
DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
updaterState = Nd4j.read(dis);
@@ -275,9 +295,9 @@ public class ModelSerializer {
}
}
- ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN);
+ byte[] prep = zipFile.get(PREPROCESSOR_BIN);
if (prep != null) {
- InputStream stream = zipFile.getInputStream(prep);
+ InputStream stream = new ByteArrayInputStream(prep);
ObjectInputStream ois = new ObjectInputStream(stream);
try {
@@ -290,7 +310,6 @@ public class ModelSerializer {
}
- zipFile.close();
if (gotConfig && gotCoefficients) {
MultiLayerConfiguration confFromJson;
@@ -328,31 +347,6 @@ public class ModelSerializer {
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
}
-
- /**
- * Load a MultiLayerNetwork from InputStream from an input stream
- * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
- *
- * @param is the inputstream to load from
- * @return the loaded multi layer network
- * @throws IOException
- * @see #restoreMultiLayerNetworkAndNormalizer(InputStream, boolean)
- */
- public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater)
- throws IOException {
- checkInputStream(is);
-
- File tmpFile = null;
- try{
- tmpFile = tempFileFromStream(is);
- return restoreMultiLayerNetwork(tmpFile, loadUpdater);
- } finally {
- if(tmpFile != null){
- tmpFile.delete();
- }
- }
- }
-
/**
* Restore a multi layer network from an input stream
* * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
@@ -404,15 +398,9 @@ public class ModelSerializer {
@NonNull InputStream is, boolean loadUpdater) throws IOException {
checkInputStream(is);
- File tmpFile = null;
- try {
- tmpFile = tempFileFromStream(is);
- return restoreMultiLayerNetworkAndNormalizer(tmpFile, loadUpdater);
- } finally {
- if (tmpFile != null) {
- tmpFile.delete();
- }
- }
+ MultiLayerNetwork net = restoreMultiLayerNetwork(is, loadUpdater);
+ Normalizer norm = restoreNormalizerFromInputStream(is);
+ return new Pair<>(net, norm);
}
/**
@@ -425,9 +413,7 @@ public class ModelSerializer {
*/
public static Pair restoreMultiLayerNetworkAndNormalizer(@NonNull File file, boolean loadUpdater)
throws IOException {
- MultiLayerNetwork net = restoreMultiLayerNetwork(file, loadUpdater);
- Normalizer norm = restoreNormalizerFromFile(file);
- return new Pair<>(net, norm);
+ return restoreMultiLayerNetworkAndNormalizer(new FileInputStream(file), loadUpdater);
}
/**
@@ -465,87 +451,7 @@ public class ModelSerializer {
throws IOException {
checkInputStream(is);
- File tmpFile = null;
- try{
- tmpFile = tempFileFromStream(is);
- return restoreComputationGraph(tmpFile, loadUpdater);
- } finally {
- if(tmpFile != null){
- tmpFile.delete();
- }
- }
- }
-
- /**
- * Load a computation graph from a InputStream
- * @param is the inputstream to get the computation graph from
- * @return the loaded computation graph
- *
- * @throws IOException
- */
- public static ComputationGraph restoreComputationGraph(@NonNull InputStream is) throws IOException {
- return restoreComputationGraph(is, true);
- }
-
- /**
- * Load a computation graph from a file
- * @param file the file to get the computation graph from
- * @return the loaded computation graph
- *
- * @throws IOException
- */
- public static ComputationGraph restoreComputationGraph(@NonNull File file) throws IOException {
- return restoreComputationGraph(file, true);
- }
-
- /**
- * Restore a ComputationGraph and Normalizer (if present - null if not) from the InputStream.
- * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
- *
- * @param is Input stream to read from
- * @param loadUpdater Whether to load the updater from the model or not
- * @return Model and normalizer, if present
- * @throws IOException If an error occurs when reading from the stream
- */
- public static Pair restoreComputationGraphAndNormalizer(
- @NonNull InputStream is, boolean loadUpdater) throws IOException {
- checkInputStream(is);
-
- File tmpFile = null;
- try {
- tmpFile = tempFileFromStream(is);
- return restoreComputationGraphAndNormalizer(tmpFile, loadUpdater);
- } finally {
- if (tmpFile != null) {
- tmpFile.delete();
- }
- }
- }
-
- /**
- * Restore a ComputationGraph and Normalizer (if present - null if not) from a File
- *
- * @param file File to read the model and normalizer from
- * @param loadUpdater Whether to load the updater from the model or not
- * @return Model and normalizer, if present
- * @throws IOException If an error occurs when reading from the File
- */
- public static Pair restoreComputationGraphAndNormalizer(@NonNull File file, boolean loadUpdater)
- throws IOException {
- ComputationGraph net = restoreComputationGraph(file, loadUpdater);
- Normalizer norm = restoreNormalizerFromFile(file);
- return new Pair<>(net, norm);
- }
-
- /**
- * Load a computation graph from a file
- * @param file the file to get the computation graph from
- * @return the loaded computation graph
- *
- * @throws IOException
- */
- public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
- ZipFile zipFile = new ZipFile(file);
+ Map files = loadZipData(is);
boolean gotConfig = false;
boolean gotCoefficients = false;
@@ -558,11 +464,11 @@ public class ModelSerializer {
DataSetPreProcessor preProcessor = null;
- ZipEntry config = zipFile.getEntry(CONFIGURATION_JSON);
+ byte[] config = files.get(CONFIGURATION_JSON);
if (config != null) {
//restoring configuration
- InputStream stream = zipFile.getInputStream(config);
+ InputStream stream = new ByteArrayInputStream(config);
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
String line = "";
StringBuilder js = new StringBuilder();
@@ -577,27 +483,27 @@ public class ModelSerializer {
}
- ZipEntry coefficients = zipFile.getEntry(COEFFICIENTS_BIN);
+ byte[] coefficients = files.get(COEFFICIENTS_BIN);
if (coefficients != null) {
- if(coefficients.getSize() > 0) {
- InputStream stream = zipFile.getInputStream(coefficients);
- DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
+ if(coefficients.length > 0) {
+ InputStream stream = new ByteArrayInputStream(coefficients);
+ DataInputStream dis = new DataInputStream(stream);
params = Nd4j.read(dis);
dis.close();
gotCoefficients = true;
} else {
- ZipEntry noParamsMarker = zipFile.getEntry(NO_PARAMS_MARKER);
+ byte[] noParamsMarker = files.get(NO_PARAMS_MARKER);
gotCoefficients = (noParamsMarker != null);
}
}
if (loadUpdater) {
- ZipEntry updaterStateEntry = zipFile.getEntry(UPDATER_BIN);
+ byte[] updaterStateEntry = files.get(UPDATER_BIN);
if (updaterStateEntry != null) {
- InputStream stream = zipFile.getInputStream(updaterStateEntry);
- DataInputStream dis = new DataInputStream(new BufferedInputStream(stream));
+ InputStream stream = new ByteArrayInputStream(updaterStateEntry);
+ DataInputStream dis = new DataInputStream(stream);
updaterState = Nd4j.read(dis);
dis.close();
@@ -605,9 +511,9 @@ public class ModelSerializer {
}
}
- ZipEntry prep = zipFile.getEntry(PREPROCESSOR_BIN);
+ byte[] prep = files.get(PREPROCESSOR_BIN);
if (prep != null) {
- InputStream stream = zipFile.getInputStream(prep);
+ InputStream stream = new ByteArrayInputStream(prep);
ObjectInputStream ois = new ObjectInputStream(stream);
try {
@@ -620,8 +526,6 @@ public class ModelSerializer {
}
- zipFile.close();
-
if (gotConfig && gotCoefficients) {
ComputationGraphConfiguration confFromJson;
try{
@@ -662,6 +566,70 @@ public class ModelSerializer {
+ "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]");
}
+ /**
+ * Load a computation graph from a InputStream
+ * @param is the inputstream to get the computation graph from
+ * @return the loaded computation graph
+ *
+ * @throws IOException
+ */
+ public static ComputationGraph restoreComputationGraph(@NonNull InputStream is) throws IOException {
+ return restoreComputationGraph(is, true);
+ }
+
+ /**
+ * Load a computation graph from a file
+ * @param file the file to get the computation graph from
+ * @return the loaded computation graph
+ *
+ * @throws IOException
+ */
+ public static ComputationGraph restoreComputationGraph(@NonNull File file) throws IOException {
+ return restoreComputationGraph(file, true);
+ }
+
+ /**
+ * Restore a ComputationGraph and Normalizer (if present - null if not) from the InputStream.
+ * Note: the input stream is read fully and closed by this method. Consequently, the input stream cannot be re-used.
+ *
+ * @param is Input stream to read from
+ * @param loadUpdater Whether to load the updater from the model or not
+ * @return Model and normalizer, if present
+ * @throws IOException If an error occurs when reading from the stream
+ */
+ public static Pair restoreComputationGraphAndNormalizer(
+ @NonNull InputStream is, boolean loadUpdater) throws IOException {
+ checkInputStream(is);
+
+ ComputationGraph net = restoreComputationGraph(is, loadUpdater);
+ Normalizer norm = restoreNormalizerFromInputStream(is);
+ return new Pair<>(net, norm);
+ }
+
+ /**
+ * Restore a ComputationGraph and Normalizer (if present - null if not) from a File
+ *
+ * @param file File to read the model and normalizer from
+ * @param loadUpdater Whether to load the updater from the model or not
+ * @return Model and normalizer, if present
+ * @throws IOException If an error occurs when reading from the File
+ */
+ public static Pair restoreComputationGraphAndNormalizer(@NonNull File file, boolean loadUpdater)
+ throws IOException {
+ return restoreComputationGraphAndNormalizer(new FileInputStream(file), loadUpdater);
+ }
+
+ /**
+ * Load a computation graph from a file
+ * @param file the file to get the computation graph from
+ * @return the loaded computation graph
+ *
+ * @throws IOException
+ */
+ public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
+ return restoreComputationGraph(new FileInputStream(file), loadUpdater);
+ }
+
/**
*
* @param model
@@ -811,15 +779,16 @@ public class ModelSerializer {
}
//Add new object:
- ZipEntry entry = new ZipEntry("objects/" + key);
- writeFile.putNextEntry(entry);
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
oos.writeObject(o);
byte[] bytes = baos.toByteArray();
+ ZipEntry entry = new ZipEntry("objects/" + key);
+ entry.setSize(bytes.length);
+ writeFile.putNextEntry(entry);
writeFile.write(bytes);
+ writeFile.closeEntry();
}
- writeFile.closeEntry();
writeFile.close();
zipFile.close();
@@ -904,18 +873,12 @@ public class ModelSerializer {
* @param file
* @return
*/
- public static T restoreNormalizerFromFile(File file) {
- try (ZipFile zipFile = new ZipFile(file)) {
- ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN);
-
- // checking for file existence
- if (norm == null)
- return null;
-
- return NormalizerSerializer.getDefault().restore(zipFile.getInputStream(norm));
+ public static T restoreNormalizerFromFile(File file) throws IOException {
+ try {
+ return restoreNormalizerFromInputStream(new FileInputStream(file));
} catch (Exception e) {
log.warn("Error while restoring normalizer, trying to restore assuming deprecated format...");
- DataNormalization restoredDeprecated = restoreNormalizerFromFileDeprecated(file);
+ DataNormalization restoredDeprecated = restoreNormalizerFromInputStreamDeprecated(new FileInputStream(file));
log.warn("Recovered using deprecated method. Will now re-save the normalizer to fix this issue.");
addNormalizerToModel(file, restoredDeprecated);
@@ -934,15 +897,18 @@ public class ModelSerializer {
public static T restoreNormalizerFromInputStream(InputStream is) throws IOException {
checkInputStream(is);
- File tmpFile = null;
+ Map files = loadZipData(is);
+ byte[] norm = files.get(NORMALIZER_BIN);
+
+ // checking for file existence
+ if (norm == null)
+ return null;
try {
- tmpFile = tempFileFromStream(is);
- return restoreNormalizerFromFile(tmpFile);
- } finally {
- if(tmpFile != null){
- tmpFile.delete();
- }
+ return NormalizerSerializer.getDefault().restore(new ByteArrayInputStream(norm));
}
+ catch (Exception e) {
+ throw new IOException("Error loading normalizer", e);
+ }
}
/**
@@ -953,17 +919,9 @@ public class ModelSerializer {
* @param file
* @return
*/
- private static DataNormalization restoreNormalizerFromFileDeprecated(File file) {
- try (ZipFile zipFile = new ZipFile(file)) {
- ZipEntry norm = zipFile.getEntry(NORMALIZER_BIN);
-
- // checking for file existence
- if (norm == null)
- return null;
-
- InputStream stream = zipFile.getInputStream(norm);
+ private static DataNormalization restoreNormalizerFromInputStreamDeprecated(InputStream stream) {
+ try {
ObjectInputStream ois = new ObjectInputStream(stream);
-
try {
DataNormalization normalizer = (DataNormalization) ois.readObject();
return normalizer;
@@ -996,31 +954,30 @@ public class ModelSerializer {
*/
}
- private static void checkTempFileFromInputStream(File f) throws IOException {
- if (f.length() <= 0) {
- throw new IOException("Error reading from input stream: temporary file is empty after copying entire stream." +
- " Stream may have been closed before reading, is attempting to be used multiple times, or does not" +
- " point to a model file?");
- }
+ private static Map loadZipData(InputStream is) throws IOException {
+ Map result = new HashMap<>();
+ try (final ZipInputStream zis = new ZipInputStream(is)) {
+ while (true) {
+ final ZipEntry zipEntry = zis.getNextEntry();
+ if (zipEntry == null)
+ break;
+ if(zipEntry.isDirectory() || zipEntry.getSize() > Integer.MAX_VALUE)
+ throw new IllegalArgumentException();
+
+ final int size = (int) (zipEntry.getSize());
+ final byte[] data;
+ if (size >= 0) { // known size
+ data = IOUtils.readFully(zis, size);
+ }
+ else { // unknown size
+ final ByteArrayOutputStream bout = new ByteArrayOutputStream();
+ IOUtils.copy(zis, bout);
+ data = bout.toByteArray();
+ }
+ result.put(zipEntry.getName(), data);
+ }
+ }
+ return result;
}
- private static File tempFileFromStream(InputStream is) throws IOException{
- checkInputStream(is);
- String p = System.getProperty(DL4JSystemProperties.DL4J_TEMP_DIR_PROPERTY);
- File tmpFile = DL4JFileUtils.createTempFile("dl4jModelSerializer", "bin");
- try {
- tmpFile.deleteOnExit();
- BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(new FileOutputStream(tmpFile));
- IOUtils.copy(is, bufferedOutputStream);
- bufferedOutputStream.flush();
- IOUtils.closeQuietly(bufferedOutputStream);
- checkTempFileFromInputStream(tmpFile);
- return tmpFile;
- } catch (IOException e){
- if(tmpFile != null){
- tmpFile.delete();
- }
- throw e;
- }
- }
}