From 171ce51f46137bfe7a678ee03eff4346c89ded56 Mon Sep 17 00:00:00 2001 From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Date: Tue, 15 Oct 2019 21:56:24 -0400 Subject: [PATCH] RL4J: Use Nd4j Random instead of java.util.Random (#8282) * Changed to use Nd4j Random instead of java.util.Random Signed-off-by: unknown * Changed to use Nd4j.getRandom() instead of the factory Signed-off-by: Alexandre Boulanger --- .../rl4j/learning/ILearning.java | 2 +- .../rl4j/learning/Learning.java | 9 +- .../learning/async/AsyncConfiguration.java | 2 +- .../rl4j/learning/async/AsyncLearning.java | 4 - .../async/a3c/discrete/A3CDiscrete.java | 16 +- .../async/a3c/discrete/A3CThreadDiscrete.java | 15 +- .../discrete/AsyncNStepQLearningDiscrete.java | 8 +- .../AsyncNStepQLearningThreadDiscrete.java | 15 +- .../rl4j/learning/sync/ExpReplay.java | 16 +- .../rl4j/learning/sync/SyncLearning.java | 4 - .../learning/sync/qlearning/QLearning.java | 21 +- .../qlearning/discrete/QLearningDiscrete.java | 9 +- .../rl4j/mdp/CartpoleNative.java | 10 +- .../deeplearning4j/rl4j/policy/ACPolicy.java | 46 ++-- .../rl4j/policy/BoltzmannQ.java | 15 +- .../deeplearning4j/rl4j/policy/EpsGreedy.java | 7 +- .../learning/async/AsyncLearningTest.java | 1 - .../rl4j/learning/sync/ExpReplayTest.java | 180 ++++++++++++++++ .../rl4j/learning/sync/SyncLearningTest.java | 1 - .../discrete/QLearningDiscreteTest.java | 10 +- .../rl4j/policy/PolicyTest.java | 4 +- .../rl4j/support/MockAsyncConfiguration.java | 2 +- .../rl4j/support/MockRandom.java | 203 ++++++++++++++++++ 23 files changed, 504 insertions(+), 96 deletions(-) create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockRandom.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java index e243bdc5e..e6c803bd2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java @@ -42,7 +42,7 @@ public interface ILearning> ex interface LConfiguration { - int getSeed(); + Integer getSeed(); int getMaxEpochStep(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java index 89c7fdb59..04ff06bc6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java @@ -29,8 +29,6 @@ import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import java.util.Random; - /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/27/16. * @@ -43,8 +41,7 @@ import java.util.Random; @Slf4j public abstract class Learning, NN extends NeuralNet> implements ILearning, NeuralNetFetchable { - @Getter - final private Random random; + @Getter @Setter private int stepCounter = 0; @Getter @Setter @@ -52,10 +49,6 @@ public abstract class Learning @Getter @Setter private IHistoryProcessor historyProcessor = null; - public Learning(LConfiguration conf) { - random = new Random(conf.getSeed()); - } - public static Integer getMaxAction(INDArray vector) { return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java index 2cfa31870..0727db475 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java @@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.learning.ILearning; */ public interface AsyncConfiguration extends ILearning.LConfiguration { - int getSeed(); + Integer getSeed(); int getMaxEpochStep(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java index 40d279182..0835bf692 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java @@ -42,10 +42,6 @@ public abstract class AsyncLearning extends AsyncLearning policy; public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CConfiguration conf) { - super(conf); this.iActorCritic = iActorCritic; this.mdp = mdp; this.configuration = conf; - policy = new ACPolicy<>(iActorCritic, getRandom()); asyncGlobal = new AsyncGlobal<>(iActorCritic, conf); - mdp.getActionSpace().setSeed(conf.getSeed()); + + Integer seed = conf.getSeed(); + Random rnd = Nd4j.getRandom(); + if(seed != null) { + mdp.getActionSpace().setSeed(seed); + rnd.setSeed(seed); + } + + policy = new ACPolicy<>(iActorCritic, rnd); } protected AsyncThread newThread(int i, int deviceNum) { @@ -71,7 +79,7 @@ public abstract class A3CDiscrete extends AsyncLearning extends AsyncThreadDiscrete< @Getter final protected int threadNumber; - final private Random random; + final private Random rnd; public A3CThreadDiscrete(MDP mdp, AsyncGlobal asyncGlobal, A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners, @@ -59,13 +59,18 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< this.conf = a3cc; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; - mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber); - random = new Random(conf.getSeed() + threadNumber); + + Integer seed = conf.getSeed(); + rnd = Nd4j.getRandom(); + if(seed != null) { + mdp.getActionSpace().setSeed(seed + threadNumber); + rnd.setSeed(seed); + } } @Override protected Policy getPolicy(IActorCritic net) { - return new ACPolicy(net, random); + return new ACPolicy(net, rnd); } /** diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index cef53543a..1b423d1a2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -43,11 +43,13 @@ public abstract class AsyncNStepQLearningDiscrete public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf) { - super(conf); this.mdp = mdp; this.configuration = conf; this.asyncGlobal = new AsyncGlobal<>(dqn, conf); - mdp.getActionSpace().setSeed(conf.getSeed()); + Integer seed = conf.getSeed(); + if(seed != null) { + mdp.getActionSpace().setSeed(seed); + } } @Override @@ -70,7 +72,7 @@ public abstract class AsyncNStepQLearningDiscrete @EqualsAndHashCode(callSuper = false) public static class AsyncNStepQLConfiguration implements AsyncConfiguration { - int seed; + Integer seed; int maxEpochStep; int maxStep; int numThread; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index 6bd1c8b6d..4a51c91d2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -32,8 +32,8 @@ import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.api.rng.Random; -import java.util.Random; import java.util.Stack; /** @@ -48,7 +48,7 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn @Getter final protected int threadNumber; - final private Random random; + final private Random rnd; public AsyncNStepQLearningThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, @@ -57,13 +57,18 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn this.conf = conf; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; - mdp.getActionSpace().setSeed(conf.getSeed() + threadNumber); - random = new Random(conf.getSeed() + threadNumber); + rnd = Nd4j.getRandom(); + + Integer seed = conf.getSeed(); + if(seed != null) { + mdp.getActionSpace().setSeed(seed + threadNumber); + rnd.setSeed(seed + threadNumber); + } } public Policy getPolicy(IDQN nn) { return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(), - random, conf.getMinEpsilon(), this); + rnd, conf.getMinEpsilon(), this); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java index fd7d9465b..2defc1d75 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java @@ -20,9 +20,9 @@ import it.unimi.dsi.fastutil.ints.IntOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.queue.CircularFifoQueue; +import org.nd4j.linalg.api.rng.Random; -import java.util.*; -import java.util.concurrent.ThreadLocalRandom; +import java.util.ArrayList; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16. @@ -36,30 +36,28 @@ import java.util.concurrent.ThreadLocalRandom; public class ExpReplay implements IExpReplay { final private int batchSize; - final private Random random; + final private Random rnd; //Implementing this as a circular buffer queue private CircularFifoQueue> storage; - public ExpReplay(int maxSize, int batchSize, int seed) { + public ExpReplay(int maxSize, int batchSize, Random rnd) { this.batchSize = batchSize; - this.random = new Random(seed); + this.rnd = rnd; storage = new CircularFifoQueue<>(maxSize); } - public ArrayList> getBatch(int size) { ArrayList> batch = new ArrayList<>(size); int storageSize = storage.size(); int actualBatchSize = Math.min(storageSize, size); int[] actualIndex = new int[actualBatchSize]; - ThreadLocalRandom r = ThreadLocalRandom.current(); IntSet set = new IntOpenHashSet(); for( int i=0; i expReplay; public QLearning(QLConfiguration conf) { - super(conf); - expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), conf.getSeed()); + this(conf, getSeededRandom(conf.getSeed())); + } + + public QLearning(QLConfiguration conf, Random random) { + expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random); + } + + private static Random getSeededRandom(Integer seed) { + Random rnd = Nd4j.getRandom(); + if(seed != null) { + rnd.setSeed(seed); + } + + return rnd; } protected abstract EpsGreedy getEgPolicy(); @@ -160,7 +173,7 @@ public abstract class QLearning extends QLearning mdp, IDQN dqn, QLConfiguration conf, int epsilonNbStep) { + this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed())); + } + + public QLearningDiscrete(MDP mdp, IDQN dqn, QLConfiguration conf, + int epsilonNbStep, Random random) { super(conf); this.configuration = conf; this.mdp = mdp; qNetwork = dqn; targetQNetwork = dqn.clone(); policy = new DQNPolicy(getQNetwork()); - egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(), + egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(), this); mdp.getActionSpace().setSeed(conf.getSeed()); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java index e699075e3..424dbd57b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java @@ -57,7 +57,7 @@ public class CartpoleNative implements MDP observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES }); + public CartpoleNative() { + rnd = new Random(); + } + + public CartpoleNative(int seed) { + rnd = new Random(seed); + } + @Override public State reset() { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java index 078d4a5e1..09e396ac4 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java @@ -16,18 +16,16 @@ package org.deeplearning4j.rl4j.policy; -import org.deeplearning4j.nn.api.NeuralNetwork; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph; import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.rng.Random; +import org.nd4j.linalg.factory.Nd4j; import java.io.IOException; -import java.util.Random; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. @@ -38,47 +36,41 @@ import java.util.Random; */ public class ACPolicy extends Policy { - final private IActorCritic IActorCritic; - Random rd; + final private IActorCritic actorCritic; + Random rnd; - public ACPolicy(IActorCritic IActorCritic) { - this.IActorCritic = IActorCritic; - NeuralNetwork nn = IActorCritic.getNeuralNetworks()[0]; - if (nn instanceof ComputationGraph) { - rd = new Random(((ComputationGraph)nn).getConfiguration().getDefaultConfiguration().getSeed()); - } else if (nn instanceof MultiLayerNetwork) { - rd = new Random(((MultiLayerNetwork)nn).getDefaultConfiguration().getSeed()); - } + public ACPolicy(IActorCritic actorCritic) { + this(actorCritic, Nd4j.getRandom()); } - public ACPolicy(IActorCritic IActorCritic, Random rd) { - this.IActorCritic = IActorCritic; - this.rd = rd; + public ACPolicy(IActorCritic actorCritic, Random rnd) { + this.actorCritic = actorCritic; + this.rnd = rnd; } public static ACPolicy load(String path) throws IOException { return new ACPolicy(ActorCriticCompGraph.load(path)); } - public static ACPolicy load(String path, Random rd) throws IOException { - return new ACPolicy(ActorCriticCompGraph.load(path), rd); + public static ACPolicy load(String path, Random rnd) throws IOException { + return new ACPolicy(ActorCriticCompGraph.load(path), rnd); } public static ACPolicy load(String pathValue, String pathPolicy) throws IOException { return new ACPolicy(ActorCriticSeparate.load(pathValue, pathPolicy)); } - public static ACPolicy load(String pathValue, String pathPolicy, Random rd) throws IOException { - return new ACPolicy(ActorCriticSeparate.load(pathValue, pathPolicy), rd); + public static ACPolicy load(String pathValue, String pathPolicy, Random rnd) throws IOException { + return new ACPolicy(ActorCriticSeparate.load(pathValue, pathPolicy), rnd); } public IActorCritic getNeuralNet() { - return IActorCritic; + return actorCritic; } public Integer nextAction(INDArray input) { - INDArray output = IActorCritic.outputAll(input)[1]; - if (rd == null) { + INDArray output = actorCritic.outputAll(input)[1]; + if (rnd == null) { return Learning.getMaxAction(output); } - float rVal = rd.nextFloat(); + float rVal = rnd.nextFloat(); for (int i = 0; i < output.length(); i++) { //System.out.println(i + " " + rVal + " " + output.getFloat(i)); if (rVal < output.getFloat(i)) { @@ -91,11 +83,11 @@ public class ACPolicy extends Policy { } public void save(String filename) throws IOException { - IActorCritic.save(filename); + actorCritic.save(filename); } public void save(String filenameValue, String filenamePolicy) throws IOException { - IActorCritic.save(filenameValue, filenamePolicy); + actorCritic.save(filenameValue, filenamePolicy); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java index d9858b89a..6ed7d4557 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java @@ -16,12 +16,11 @@ package org.deeplearning4j.rl4j.policy; -import lombok.AllArgsConstructor; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; - -import java.util.Random; +import org.nd4j.linalg.api.rng.Random; +import org.nd4j.linalg.factory.Nd4j; import static org.nd4j.linalg.ops.transforms.Transforms.exp; @@ -31,11 +30,15 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp; * Boltzmann exploration is a stochastic policy wrt to the * exponential Q-values as evaluated by the dqn model. */ -@AllArgsConstructor public class BoltzmannQ extends Policy { final private IDQN dqn; - final private Random rd = new Random(123); + final private Random rnd; + + public BoltzmannQ(IDQN dqn, Random random) { + this.dqn = dqn; + this.rnd = random; + } public IDQN getNeuralNet() { return dqn; @@ -47,7 +50,7 @@ public class BoltzmannQ extends Policy { INDArray exp = exp(output); double sum = exp.sum(1).getDouble(0); - double picked = rd.nextDouble() * sum; + double picked = rnd.nextDouble() * sum; for (int i = 0; i < exp.columns(); i++) { if (picked < exp.getDouble(i)) return i; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java index dc954395d..a7be53596 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java @@ -24,8 +24,7 @@ import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; - -import java.util.Random; +import org.nd4j.linalg.api.rng.Random; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16. @@ -45,7 +44,7 @@ public class EpsGreedy> extend final private MDP mdp; final private int updateStart; final private int epsilonNbStep; - final private Random rd; + final private Random rnd; final private float minEpsilon; final private StepCountable learning; @@ -58,7 +57,7 @@ public class EpsGreedy> extend float ep = getEpsilon(); if (learning.getStepCounter() % 500 == 1) log.info("EP: " + ep + " " + learning.getStepCounter()); - if (rd.nextFloat() > ep) + if (rnd.nextFloat() > ep) return policy.nextAction(input); else return mdp.getActionSpace().randomAction(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java index ec0bca94f..536c6a8ad 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java @@ -87,7 +87,6 @@ public class AsyncLearningTest { private final IPolicy policy; public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy policy) { - super(conf); this.conf = conf; this.asyncGlobal = asyncGlobal; this.policy = policy; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java new file mode 100644 index 000000000..44271adde --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java @@ -0,0 +1,180 @@ +package org.deeplearning4j.rl4j.learning.sync; + +import org.deeplearning4j.rl4j.support.MockRandom; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class ExpReplayTest { + @Test + public void when_storingElementWithStorageNotFull_expect_elementStored() { + // Arrange + MockRandom randomMock = new MockRandom(null, new int[] { 0 }); + ExpReplay sut = new ExpReplay(2, 1, randomMock); + + // Act + Transition transition = new Transition(new INDArray[] { Nd4j.create(1) }, 123, 234, false, Nd4j.create(1)); + sut.store(transition); + List> results = sut.getBatch(1); + + // Assert + assertEquals(1, results.size()); + assertEquals(123, (int)results.get(0).getAction()); + assertEquals(234, (int)results.get(0).getReward()); + } + + @Test + public void when_storingElementWithStorageFull_expect_oldestElementReplacedByStored() { + // Arrange + MockRandom randomMock = new MockRandom(null, new int[] { 0, 1 }); + ExpReplay sut = new ExpReplay(2, 1, randomMock); + + // Act + Transition transition1 = new Transition(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); + Transition transition2 = new Transition(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); + Transition transition3 = new Transition(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); + sut.store(transition1); + sut.store(transition2); + sut.store(transition3); + List> results = sut.getBatch(2); + + // Assert + assertEquals(2, results.size()); + + assertEquals(3, (int)results.get(0).getAction()); + assertEquals(4, (int)results.get(0).getReward()); + + assertEquals(5, (int)results.get(1).getAction()); + assertEquals(6, (int)results.get(1).getReward()); + } + + + @Test + public void when_askBatchSizeZeroAndStorageEmpty_expect_emptyBatch() { + // Arrange + MockRandom randomMock = new MockRandom(null, new int[] { 0 }); + ExpReplay sut = new ExpReplay(5, 1, randomMock); + + // Act + List> results = sut.getBatch(0); + + // Assert + assertEquals(0, results.size()); + } + + @Test + public void when_askBatchSizeZeroAndStorageNotEmpty_expect_emptyBatch() { + // Arrange + MockRandom randomMock = new MockRandom(null, new int[] { 0 }); + ExpReplay sut = new ExpReplay(5, 1, randomMock); + + // Act + Transition transition1 = new Transition(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); + Transition transition2 = new Transition(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); + Transition transition3 = new Transition(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); + sut.store(transition1); + sut.store(transition2); + sut.store(transition3); + List> results = sut.getBatch(0); + + // Assert + assertEquals(0, results.size()); + } + + @Test + public void when_askBatchSizeGreaterThanStoredCount_expect_batchWithStoredCountElements() { + // Arrange + MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2 }); + ExpReplay sut = new ExpReplay(5, 1, randomMock); + + // Act + Transition transition1 = new Transition(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); + Transition transition2 = new Transition(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); + Transition transition3 = new Transition(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); + sut.store(transition1); + sut.store(transition2); + sut.store(transition3); + List> results = sut.getBatch(10); + + // Assert + assertEquals(3, results.size()); + + assertEquals(1, (int)results.get(0).getAction()); + assertEquals(2, (int)results.get(0).getReward()); + + assertEquals(3, (int)results.get(1).getAction()); + assertEquals(4, (int)results.get(1).getReward()); + + assertEquals(5, (int)results.get(2).getAction()); + assertEquals(6, (int)results.get(2).getReward()); + } + + @Test + public void when_askBatchSizeSmallerThanStoredCount_expect_batchWithAskedElements() { + // Arrange + MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2, 3, 4 }); + ExpReplay sut = new ExpReplay(5, 1, randomMock); + + // Act + Transition transition1 = new Transition(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); + Transition transition2 = new Transition(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); + Transition transition3 = new Transition(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); + Transition transition4 = new Transition(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1)); + Transition transition5 = new Transition(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1)); + sut.store(transition1); + sut.store(transition2); + sut.store(transition3); + sut.store(transition4); + sut.store(transition5); + List> results = sut.getBatch(3); + + // Assert + assertEquals(3, results.size()); + + assertEquals(1, (int)results.get(0).getAction()); + assertEquals(2, (int)results.get(0).getReward()); + + assertEquals(3, (int)results.get(1).getAction()); + assertEquals(4, (int)results.get(1).getReward()); + + assertEquals(5, (int)results.get(2).getAction()); + assertEquals(6, (int)results.get(2).getReward()); + } + + @Test + public void when_randomGivesDuplicates_expect_noDuplicatesInBatch() { + // Arrange + MockRandom randomMock = new MockRandom(null, new int[] { 0, 1, 2, 1, 3, 1, 4 }); + ExpReplay sut = new ExpReplay(5, 1, randomMock); + + // Act + Transition transition1 = new Transition(new INDArray[] { Nd4j.create(1) }, 1, 2, false, Nd4j.create(1)); + Transition transition2 = new Transition(new INDArray[] { Nd4j.create(1) }, 3, 4, false, Nd4j.create(1)); + Transition transition3 = new Transition(new INDArray[] { Nd4j.create(1) }, 5, 6, false, Nd4j.create(1)); + Transition transition4 = new Transition(new INDArray[] { Nd4j.create(1) }, 7, 8, false, Nd4j.create(1)); + Transition transition5 = new Transition(new INDArray[] { Nd4j.create(1) }, 9, 10, false, Nd4j.create(1)); + sut.store(transition1); + sut.store(transition2); + sut.store(transition3); + sut.store(transition4); + sut.store(transition5); + List> results = sut.getBatch(3); + + // Assert + assertEquals(3, results.size()); + + assertEquals(1, (int)results.get(0).getAction()); + assertEquals(2, (int)results.get(0).getReward()); + + assertEquals(3, (int)results.get(1).getAction()); + assertEquals(4, (int)results.get(1).getReward()); + + assertEquals(5, (int)results.get(2).getAction()); + assertEquals(6, (int)results.get(2).getReward()); + + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java index 7e7c3eb01..9b89390d8 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java @@ -89,7 +89,6 @@ public class SyncLearningTest { private final LConfiguration conf; public MockSyncLearning(LConfiguration conf) { - super(conf); this.conf = conf; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index 2982a1d21..b5212566d 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -13,6 +13,7 @@ import org.deeplearning4j.rl4j.util.IDataManager; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; @@ -27,11 +28,12 @@ public class QLearningDiscreteTest { MockObservationSpace observationSpace = new MockObservationSpace(); MockMDP mdp = new MockMDP(observationSpace); MockDQN dqn = new MockDQN(); + MockRandom random = new MockRandom(new double[] { 0.7309677600860596, 0.8314409852027893, 0.2405363917350769, 0.6063451766967773, 0.6374173760414124, 0.3090505599975586, 0.5504369735717773, 0.11700659990310669 }, null); QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, 0, 1.0, 0, 0, 0, 0, true); MockDataManager dataManager = new MockDataManager(false); MockExpReplay expReplay = new MockExpReplay(); - TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10); + TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); sut.setHistoryProcessor(hp); @@ -130,10 +132,10 @@ public class QLearningDiscreteTest { } public static class TestQLearningDiscrete extends QLearningDiscrete { - public TestQLearningDiscrete(MDP mdp,IDQN dqn, + public TestQLearningDiscrete(MDP mdp, IDQN dqn, QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay, - int epsilonNbStep) { - super(mdp, dqn, conf, epsilonNbStep); + int epsilonNbStep, Random rnd) { + super(mdp, dqn, conf, epsilonNbStep, rnd); addListener(new DataManagerTrainingListener(dataManager)); setExpReplay(expReplay); } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index c2f237c53..2dacd88e1 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -127,10 +127,10 @@ public class PolicyTest { .layer(0, new OutputLayer.Builder().nOut(1).lossFunction(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID).build()).build()); ACPolicy policy = new ACPolicy(new DummyAC(cg)); - assertNotNull(policy.rd); + assertNotNull(policy.rnd); policy = new ACPolicy(new DummyAC(mln)); - assertNotNull(policy.rd); + assertNotNull(policy.rnd); INDArray input = Nd4j.create(new double[] {1.0, 0.0}, new long[]{1,2}); for (int i = 0; i < 100; i++) { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java index a40de0e91..1706dc49e 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java @@ -14,7 +14,7 @@ public class MockAsyncConfiguration implements AsyncConfiguration { } @Override - public int getSeed() { + public Integer getSeed() { return 0; } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockRandom.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockRandom.java new file mode 100644 index 000000000..53de05bd9 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockRandom.java @@ -0,0 +1,203 @@ +package org.deeplearning4j.rl4j.support; + +import org.bytedeco.javacpp.Pointer; +import org.nd4j.linalg.api.ndarray.INDArray; + +public class MockRandom implements org.nd4j.linalg.api.rng.Random { + + private int randomDoubleValuesIdx = 0; + private final double[] randomDoubleValues; + + private int randomIntValuesIdx = 0; + private final int[] randomIntValues; + + public MockRandom(double[] randomDoubleValues, int[] randomIntValues) { + this.randomDoubleValues = randomDoubleValues; + this.randomIntValues = randomIntValues; + } + + @Override + public void setSeed(int i) { + + } + + @Override + public void setSeed(int[] ints) { + + } + + @Override + public void setSeed(long l) { + + } + + @Override + public long getSeed() { + return 0; + } + + @Override + public void nextBytes(byte[] bytes) { + + } + + @Override + public int nextInt() { + return randomIntValues[randomIntValuesIdx++]; + } + + @Override + public int nextInt(int i) { + return randomIntValues[randomIntValuesIdx++]; + } + + @Override + public int nextInt(int i, int i1) { + return randomIntValues[randomIntValuesIdx++]; + } + + @Override + public long nextLong() { + return randomIntValues[randomIntValuesIdx++]; + } + + @Override + public boolean nextBoolean() { + return false; + } + + @Override + public float nextFloat() { + return (float)randomDoubleValues[randomDoubleValuesIdx++]; + } + + @Override + public double nextDouble() { + return randomDoubleValues[randomDoubleValuesIdx++]; + } + + @Override + public double nextGaussian() { + return 0; + } + + @Override + public INDArray nextGaussian(int[] ints) { + return null; + } + + @Override + public INDArray nextGaussian(long[] longs) { + return null; + } + + @Override + public INDArray nextGaussian(char c, int[] ints) { + return null; + } + + @Override + public INDArray nextGaussian(char c, long[] longs) { + return null; + } + + @Override + public INDArray nextDouble(int[] ints) { + return null; + } + + @Override + public INDArray nextDouble(long[] longs) { + return null; + } + + @Override + public INDArray nextDouble(char c, int[] ints) { + return null; + } + + @Override + public INDArray nextDouble(char c, long[] longs) { + return null; + } + + @Override + public INDArray nextFloat(int[] ints) { + return null; + } + + @Override + public INDArray nextFloat(long[] longs) { + return null; + } + + @Override + public INDArray nextFloat(char c, int[] ints) { + return null; + } + + @Override + public INDArray nextFloat(char c, long[] longs) { + return null; + } + + @Override + public INDArray nextInt(int[] ints) { + return null; + } + + @Override + public INDArray nextInt(long[] longs) { + return null; + } + + @Override + public INDArray nextInt(int i, int[] ints) { + return null; + } + + @Override + public INDArray nextInt(int i, long[] longs) { + return null; + } + + @Override + public Pointer getStatePointer() { + return null; + } + + @Override + public long getPosition() { + return 0; + } + + @Override + public void reSeed() { + + } + + @Override + public void reSeed(long l) { + + } + + @Override + public long rootState() { + return 0; + } + + @Override + public long nodeState() { + return 0; + } + + @Override + public void setStates(long l, long l1) { + + } + + @Override + public void close() throws Exception { + + } +}