From 95100ffd8cb80a0c1c4896213a2450bc3446096d Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 17 Aug 2019 14:13:31 +1000 Subject: [PATCH] Small build fixes (#127) * Small build fixes Signed-off-by: Alex Black * Fix RL4J Signed-off-by: Alex Black * Test fixes Signed-off-by: Alex Black * Another fix Signed-off-by: Alex Black --- .../embeddings/word2vec/Word2VecPerformer.java | 11 +++-------- .../word2vec/Word2VecPerformerVoid.java | 12 +++--------- .../training/SharedTrainingWorker.java | 4 ++-- .../rl4j/learning/async/AsyncLearning.java | 6 ++---- .../rl4j/learning/async/AsyncThread.java | 7 ++++++- .../rl4j/learning/async/AsyncThreadDiscrete.java | 4 ++-- .../learning/async/a3c/discrete/A3CDiscrete.java | 6 +++--- .../async/a3c/discrete/A3CDiscreteConv.java | 4 ++-- .../async/a3c/discrete/A3CThreadDiscrete.java | 4 ++-- .../discrete/AsyncNStepQLearningDiscrete.java | 6 +++--- .../discrete/AsyncNStepQLearningDiscreteConv.java | 4 ++-- .../AsyncNStepQLearningThreadDiscrete.java | 4 ++-- .../rl4j/learning/sync/ExpReplay.java | 15 ++++++++++++++- .../rl4j/learning/sync/qlearning/QLearning.java | 2 +- .../rl4j/learning/async/AsyncThreadTest.java | 2 +- .../qlearning/discrete/QLearningDiscreteTest.java | 6 ++++++ 16 files changed, 54 insertions(+), 43 deletions(-) diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java index 2ae1c6f23..45eca7327 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java @@ -80,14 +80,9 @@ public class Word2VecPerformer implements VoidFunction, Ato initExpTable(); if (negative > 0 && conf.contains(Word2VecVariables.TABLE)) { - try { - ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(Word2VecVariables.TABLE).getBytes()); - DataInputStream dis = new DataInputStream(bis); - table = Nd4j.read(dis); - } catch (IOException e) { - e.printStackTrace(); - } - + ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(Word2VecVariables.TABLE).getBytes()); + DataInputStream dis = new DataInputStream(bis); + table = Nd4j.read(dis); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java index 4d182b90f..539755ee6 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java @@ -95,16 +95,10 @@ public class Word2VecPerformerVoid implements VoidFunction, initExpTable(); if (negative > 0 && conf.contains(TABLE)) { - try { - ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(TABLE).getBytes()); - DataInputStream dis = new DataInputStream(bis); - table = Nd4j.read(dis); - } catch (IOException e) { - e.printStackTrace(); - } - + ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(TABLE).getBytes()); + DataInputStream dis = new DataInputStream(bis); + table = Nd4j.read(dis); } - } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java index cbca7d52b..8cadbea43 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java @@ -86,7 +86,7 @@ public class SharedTrainingWorker extends BaseTrainingWorker getAsyncGlobal(); @@ -60,9 +60,7 @@ public abstract class AsyncLearning asyncGlobal, int threadNumber) { + public AsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, int deviceNum) { this.threadNumber = threadNumber; + this.deviceNum = deviceNum; } public void setHistoryProcessor(IHistoryProcessor.Configuration conf) { @@ -87,6 +91,7 @@ public abstract class AsyncThread asyncGlobal, int threadNumber) { - super(asyncGlobal, threadNumber); + public AsyncThreadDiscrete(IAsyncGlobal asyncGlobal, int threadNumber, int deviceNum) { + super(asyncGlobal, threadNumber, deviceNum); synchronized (asyncGlobal) { current = (NN)asyncGlobal.getCurrent().clone(); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java index 5777e2394..7dbec6210 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java @@ -62,9 +62,9 @@ public abstract class A3CDiscrete extends AsyncLearning extends A3CDiscrete { } @Override - public AsyncThread newThread(int i) { - AsyncThread at = super.newThread(i); + public AsyncThread newThread(int i, int deviceNum) { + AsyncThread at = super.newThread(i, deviceNum); at.setHistoryProcessor(hpconf); return at; } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index 4c5873b11..3a481b09c 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -57,8 +57,8 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< final private Random random; public A3CThreadDiscrete(MDP mdp, AsyncGlobal asyncGlobal, - A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager) { - super(asyncGlobal, threadNumber); + A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager, int deviceNum) { + super(asyncGlobal, threadNumber, deviceNum); this.conf = a3cc; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index f0d7a3349..bab60fec4 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -55,9 +55,9 @@ public abstract class AsyncNStepQLearningDiscrete mdp.getActionSpace().setSeed(conf.getSeed()); } - - public AsyncThread newThread(int i) { - return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager); + @Override + public AsyncThread newThread(int i, int deviceNum) { + return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager, deviceNum); } public IDQN getNeuralNet() { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java index 4da14012e..257e5fb5d 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java @@ -53,8 +53,8 @@ public class AsyncNStepQLearningDiscreteConv extends AsyncN } @Override - public AsyncThread newThread(int i) { - AsyncThread at = super.newThread(i); + public AsyncThread newThread(int i, int deviceNum) { + AsyncThread at = super.newThread(i, deviceNum); at.setHistoryProcessor(hpconf); return at; } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index 4f6c3ad09..23d6f79ca 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -56,8 +56,8 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn public AsyncNStepQLearningThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber, - IDataManager dataManager) { - super(asyncGlobal, threadNumber); + IDataManager dataManager, int deviceNum) { + super(asyncGlobal, threadNumber, deviceNum); this.conf = conf; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java index ead3f4d00..fd7d9465b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java @@ -16,6 +16,8 @@ package org.deeplearning4j.rl4j.learning.sync; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.queue.CircularFifoQueue; @@ -50,7 +52,18 @@ public class ExpReplay implements IExpReplay { ArrayList> batch = new ArrayList<>(size); int storageSize = storage.size(); int actualBatchSize = Math.min(storageSize, size); - int[] actualIndex = ThreadLocalRandom.current().ints(0, storageSize).distinct().limit(actualBatchSize).toArray(); + + int[] actualIndex = new int[actualBatchSize]; + ThreadLocalRandom r = ThreadLocalRandom.current(); + IntSet set = new IntOpenHashSet(); + for( int i=0; i trans = storage.get(actualIndex[i]); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index a2c25a43c..525995455 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -50,7 +50,7 @@ public abstract class QLearning expReplay; @Getter @Setter(AccessLevel.PACKAGE) - private IExpReplay expReplay; + protected IExpReplay expReplay; public QLearning(QLConfiguration conf) { super(conf); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index 0e53103ef..23be44f01 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -194,7 +194,7 @@ public class AsyncThreadTest { private final IDataManager dataManager; public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) { - super(asyncGlobal, threadNumber); + super(asyncGlobal, threadNumber, 0); this.asyncGlobal = asyncGlobal; this.neuralNet = neuralNet; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index 51bdeaf41..5762875aa 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -1,6 +1,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.mdp.MDP; @@ -138,5 +139,10 @@ public class QLearningDiscreteTest { protected Pair setTarget(ArrayList> transitions) { return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 })); } + + public void setExpReplay(IExpReplay exp){ + this.expReplay = exp; + } + } }