diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java
index b8bee4a79..6ed7819df 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java
@@ -36,6 +36,7 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.optimize.api.TrainingListener;
+import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
 import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
 import org.junit.Ignore;
 import org.junit.Test;
@@ -242,7 +243,10 @@ public class DataSetIteratorTest extends BaseDL4JTest {
         MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
         model.init();
 
-        model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
+        //model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
+
+        CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq);
+        model.setListeners(listener);
 
         model.fit(cifar);
 
@@ -254,6 +258,7 @@ public class DataSetIteratorTest extends BaseDL4JTest {
             eval.eval(testDS.getLabels(), output);
         }
         System.out.println(eval.stats(true));
+        listener.exportScores(System.out);
     }
 
 
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java
index cc3ec16fb..f6c2e269c 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java
@@ -25,6 +25,7 @@ import java.io.FileOutputStream;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 
 /**
@@ -37,7 +38,101 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
 
     private int frequency;
     private int iterationCount = 0;
-    private List<Pair<Integer, Double>> scoreVsIter = new ArrayList<>();
+
+    /**
+     * Append-only store of (iteration index, score) pairs kept in primitive arrays.
+     * <p>
+     * Points are held in a list of buckets. The last bucket grows in steps of
+     * {@link #BUCKET_LENGTH}; once it reaches {@link #MAX_BUCKET_CAPACITY} a new bucket is
+     * appended, so more points than fit into one Java array can be recorded.
+     */
+    public static class ScoreStat {
+        /** Initial size of a bucket, and the step by which it grows. */
+        public static final int BUCKET_LENGTH = 10000;
+
+        /** Largest size one bucket may reach: the biggest multiple of BUCKET_LENGTH that fits in an int. */
+        public static final int MAX_BUCKET_CAPACITY = (Integer.MAX_VALUE / BUCKET_LENGTH) * BUCKET_LENGTH;
+
+        /** Next free slot in the last bucket. */
+        private int position = 0;
+        private final List<long[]> indexes;
+        private final List<double[]> scores;
+
+        public ScoreStat() {
+            indexes = new ArrayList<>(1);
+            indexes.add(new long[BUCKET_LENGTH]);
+            scores = new ArrayList<>(1);
+            scores.add(new double[BUCKET_LENGTH]);
+        }
+
+        /**
+         * @return the internal list of index buckets (not a copy); the last bucket may end with unused slots
+         */
+        public List<long[]> getIndexes() {
+            return indexes;
+        }
+
+        /**
+         * @return the internal list of score buckets (not a copy); the last bucket may end with unused slots
+         */
+        public List<double[]> getScores() {
+            return scores;
+        }
+
+        /**
+         * @return a copy of the filled part of the first index bucket; later buckets are not included
+         */
+        public long[] getEffectiveIndexes() {
+            return Arrays.copyOf(indexes.get(0), firstBucketSize());
+        }
+
+        /**
+         * @return a copy of the filled part of the first score bucket; later buckets are not included
+         */
+        public double[] getEffectiveScores() {
+            return Arrays.copyOf(scores.get(0), firstBucketSize());
+        }
+
+        // While there is a single bucket only "position" slots are filled; once a second
+        // bucket exists the first one is completely full.
+        private int firstBucketSize() {
+            return indexes.size() > 1 ? indexes.get(0).length : position;
+        }
+
+        /*
+            Makes sure the last bucket has a free slot at "position".
+            A full bucket is enlarged by BUCKET_LENGTH and its data copied over. A bucket that
+            has already reached MAX_BUCKET_CAPACITY is left as is and a new one is appended.
+         */
+        private void reallocateGuard() {
+            int last = scores.size() - 1;
+            int capacity = scores.get(last).length;
+            if (position < capacity) {
+                return;
+            }
+
+            if (capacity >= MAX_BUCKET_CAPACITY) {
+                indexes.add(new long[BUCKET_LENGTH]);
+                scores.add(new double[BUCKET_LENGTH]);
+                position = 0;
+            } else {
+                // capacity is a multiple of BUCKET_LENGTH below the cap, so this sum cannot overflow
+                int grown = capacity + BUCKET_LENGTH;
+                indexes.set(last, Arrays.copyOf(indexes.get(last), grown));
+                scores.set(last, Arrays.copyOf(scores.get(last), grown));
+            }
+        }
+
+        public void addScore(long index, double score) {
+            reallocateGuard();
+            int last = scores.size() - 1;
+            scores.get(last)[position] = score;
+            indexes.get(last)[position] = index;
+            position += 1;
+        }
+    }
+
+    private ScoreStat scoreVsIter = new ScoreStat();
 
     /**
      * Constructor for collecting scores with default saving frequency of 1
@@ -60,11 +155,11 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
     public void iterationDone(Model model, int iteration, int epoch) {
         if (++iterationCount % frequency == 0) {
             double score = model.score();
-            scoreVsIter.add(new Pair<>(iterationCount, score));
+            scoreVsIter.addScore(iteration, score);
         }
     }
 
-    public List<Pair<Integer, Double>> getScoreVsIter() {
+    public ScoreStat getScoreVsIter() {
         return scoreVsIter;
     }
 
@@ -84,8 +179,16 @@ public class CollectScoresIterationListener extends BaseTrainingListener {
     public void exportScores(OutputStream outputStream, String delimiter) throws IOException {
         StringBuilder sb = new StringBuilder();
         sb.append("Iteration").append(delimiter).append("Score");
-        for (Pair<Integer, Double> p : scoreVsIter) {
-            sb.append("\n").append(p.getFirst()).append(delimiter).append(p.getSecond());
+        int largeBuckets = scoreVsIter.indexes.size();
+        for (int j = 0; j < largeBuckets; ++j) {
+            long[] indexes = scoreVsIter.indexes.get(j);
+            double[] scores = scoreVsIter.scores.get(j);
+
+            int effectiveLength = (j < largeBuckets -1) ? indexes.length : scoreVsIter.position;
+
+            for (int i = 0; i < effectiveLength; ++i) {
+                sb.append("\n").append(indexes[i]).append(delimiter).append(scores[i]);
+            }
         }
         outputStream.write(sb.toString().getBytes("UTF-8"));
     }
diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/listeners/ScoreStatTest.java b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/listeners/ScoreStatTest.java
new file mode 100644
index 000000000..5a3588e25
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/listeners/ScoreStatTest.java
@@ -0,0 +1,98 @@
+package org.deeplearning4j.optimize.listeners;
+
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.util.List;
+import static org.junit.Assert.*;
+
+public class ScoreStatTest {
+    @Test
+    public void testScoreStatSmall() {
+        CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
+        for (int i = 0; i < CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH; ++i) {
+            double score = (double)i;
+            statTest.addScore(i, score);
+        }
+
+        List<long[]> indexes = statTest.getIndexes();
+        List<double[]> scores = statTest.getScores();
+
+        assertTrue(indexes.size() == 1);
+        assertTrue(scores.size() == 1);
+
+        assertTrue(indexes.get(0).length == CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH);
+        assertTrue(scores.get(0).length == CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH);
+        assertEquals(indexes.get(0)[indexes.get(0).length-1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH-1);
+        assertEquals(scores.get(0)[scores.get(0).length-1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH-1, 1e-4);
+    }
+
+    @Test
+    public void testScoreStatAverage() {
+        int dataSize = 1000000;
+        long[] indexes = new long[dataSize];
+        double[] scores = new double[dataSize];
+
+        for (int i = 0; i < dataSize; ++i) {
+            indexes[i] = i;
+            scores[i] = i;
+        }
+
+        CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
+        for (int i = 0; i < dataSize; ++i) {
+            statTest.addScore(indexes[i], scores[i]);
+        }
+
+        long[] indexesStored = statTest.getIndexes().get(0);
+        double[] scoresStored = statTest.getScores().get(0);
+
+        assertArrayEquals(indexes, indexesStored);
+        assertArrayEquals(scores, scoresStored, 1e-4);
+    }
+
+    @Test
+    public void testScoresClean() {
+        int dataSize = 10256; // forces one growth step: the single bucket is enlarged from 10k to 20k slots
+        long[] indexes = new long[dataSize];
+        double[] scores = new double[dataSize];
+
+        for (int i = 0; i < dataSize; ++i) {
+            indexes[i] = i;
+            scores[i] = i;
+        }
+
+        CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
+        for (int i = 0; i < dataSize; ++i) {
+            statTest.addScore(indexes[i], scores[i]);
+        }
+
+        long[] indexesEffective = statTest.getEffectiveIndexes();
+        double[] scoresEffective = statTest.getEffectiveScores();
+
+        assertArrayEquals(indexes, indexesEffective);
+        assertArrayEquals(scores, scoresEffective, 1e-4);
+    }
+
+    @Ignore
+    @Test
+    public void testScoreStatBig() {
+        CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat();
+        long firstBucket = CollectScoresIterationListener.ScoreStat.MAX_BUCKET_CAPACITY;
+        long bigLength = firstBucket + 5;
+        for (long i = 0; i < bigLength; ++i) {
+            double score = (double)i;
+            statTest.addScore(i, score);
+        }
+
+        List<long[]> indexes = statTest.getIndexes();
+        List<double[]> scores = statTest.getScores();
+
+        assertTrue(indexes.size() == 2);
+        assertTrue(scores.size() == 2);
+
+        for (int i = 0; i < 5; ++i) {
+            assertTrue(indexes.get(1)[i] == firstBucket + i);
+            assertTrue(scores.get(1)[i] == firstBucket + i);
+        }
+    }
+}