From 550e84ef43a32442dbd0d28f97ef164ba805f684 Mon Sep 17 00:00:00 2001 From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Date: Tue, 21 Apr 2020 20:13:08 -0400 Subject: [PATCH] RL4J: Add Agent and Environment (#358) * Added Agent and Environment Signed-off-by: Alexandre Boulanger * Added headers Signed-off-by: Alexandre Boulanger * Fix compilation errors Signed-off-by: Samuel Audet --- .../org/deeplearning4j/rl4j/agent/Agent.java | 210 ++++++++ .../rl4j/agent/listener/AgentListener.java | 23 + .../agent/listener/AgentListenerList.java | 50 ++ .../rl4j/environment/ActionSchema.java | 9 + .../rl4j/environment/Environment.java | 11 + .../rl4j/environment/Schema.java | 8 + .../rl4j/environment/StepResult.java | 12 + .../rl4j/learning/ILearning.java | 2 +- .../rl4j/learning/async/AsyncThread.java | 2 +- .../learning/async/AsyncThreadDiscrete.java | 2 +- .../async/a3c/discrete/A3CThreadDiscrete.java | 2 +- .../discrete/AsyncNStepQLearningDiscrete.java | 2 +- .../AsyncNStepQLearningThreadDiscrete.java | 2 +- .../rl4j/mdp/CartpoleEnvironment.java | 129 +++++ .../deeplearning4j/rl4j/policy/ACPolicy.java | 2 +- .../rl4j/policy/BoltzmannQ.java | 2 +- .../deeplearning4j/rl4j/policy/DQNPolicy.java | 4 +- .../deeplearning4j/rl4j/policy/EpsGreedy.java | 4 +- .../deeplearning4j/rl4j/policy/IPolicy.java | 14 +- .../deeplearning4j/rl4j/policy/Policy.java | 13 +- .../deeplearning4j/rl4j/agent/AgentTest.java | 483 ++++++++++++++++++ .../async/AsyncThreadDiscreteTest.java | 2 +- .../rl4j/policy/PolicyTest.java | 13 +- .../rl4j/support/MockPolicy.java | 10 +- 24 files changed, 979 insertions(+), 32 deletions(-) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java create mode 100644 
rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/ActionSchema.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java new file mode 100644 index 000000000..1b4a2699d --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java @@ -0,0 +1,210 @@ +package org.deeplearning4j.rl4j.agent; + +import lombok.AccessLevel; +import lombok.Getter; +import lombok.NonNull; +import org.deeplearning4j.rl4j.agent.listener.AgentListener; +import org.deeplearning4j.rl4j.agent.listener.AgentListenerList; +import org.deeplearning4j.rl4j.environment.Environment; +import org.deeplearning4j.rl4j.environment.StepResult; +import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.observation.transform.TransformProcess; +import org.deeplearning4j.rl4j.policy.IPolicy; +import org.nd4j.base.Preconditions; + +import java.util.Map; + +public class Agent { + @Getter + private final String id; + + @Getter + private final Environment environment; + + @Getter + private final IPolicy policy; + + private final TransformProcess transformProcess; + + protected final AgentListenerList listeners; + + private final Integer maxEpisodeSteps; + + @Getter(AccessLevel.PROTECTED) + private Observation observation; + + @Getter(AccessLevel.PROTECTED) + private ACTION lastAction; + + @Getter + private int episodeStepNumber; + + @Getter + private double 
reward; + + protected boolean canContinue; + + private Agent(Builder builder) { + this.environment = builder.environment; + this.transformProcess = builder.transformProcess; + this.policy = builder.policy; + this.maxEpisodeSteps = builder.maxEpisodeSteps; + this.id = builder.id; + + listeners = buildListenerList(); + } + + protected AgentListenerList buildListenerList() { + return new AgentListenerList(); + } + + public void addListener(AgentListener listener) { + listeners.add(listener); + } + + public void run() { + runEpisode(); + } + + protected void onBeforeEpisode() { + // Do Nothing + } + + protected void onAfterEpisode() { + // Do Nothing + } + + protected void runEpisode() { + reset(); + onBeforeEpisode(); + + canContinue = listeners.notifyBeforeEpisode(this); + + while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepNumber < maxEpisodeSteps)) { + performStep(); + } + + if(!canContinue) { + return; + } + + onAfterEpisode(); + } + + protected void reset() { + resetEnvironment(); + resetPolicy(); + reward = 0; + lastAction = getInitialAction(); + canContinue = true; + } + + protected void resetEnvironment() { + episodeStepNumber = 0; + Map channelsData = environment.reset(); + this.observation = transformProcess.transform(channelsData, episodeStepNumber, false); + } + + protected void resetPolicy() { + policy.reset(); + } + + protected ACTION getInitialAction() { + return environment.getSchema().getActionSchema().getNoOp(); + } + + protected void performStep() { + + onBeforeStep(); + + ACTION action = decideAction(observation); + + canContinue = listeners.notifyBeforeStep(this, observation, action); + if(!canContinue) { + return; + } + + StepResult stepResult = act(action); + handleStepResult(stepResult); + + onAfterStep(stepResult); + + canContinue = listeners.notifyAfterStep(this, stepResult); + if(!canContinue) { + return; + } + + incrementEpisodeStepNumber(); + } + + protected void incrementEpisodeStepNumber() { 
+ ++episodeStepNumber; + } + + protected ACTION decideAction(Observation observation) { + if (!observation.isSkipped()) { + lastAction = policy.nextAction(observation); + } + + return lastAction; + } + + protected StepResult act(ACTION action) { + return environment.step(action); + } + + protected void handleStepResult(StepResult stepResult) { + observation = convertChannelDataToObservation(stepResult, episodeStepNumber + 1); + reward +=computeReward(stepResult); + } + + protected Observation convertChannelDataToObservation(StepResult stepResult, int episodeStepNumberOfObs) { + return transformProcess.transform(stepResult.getChannelsData(), episodeStepNumberOfObs, stepResult.isTerminal()); + } + + protected double computeReward(StepResult stepResult) { + return stepResult.getReward(); + } + + protected void onAfterStep(StepResult stepResult) { + // Do Nothing + } + + protected void onBeforeStep() { + // Do Nothing + } + + public static Builder builder(@NonNull Environment environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy policy) { + return new Builder<>(environment, transformProcess, policy); + } + + public static class Builder { + private final Environment environment; + private final TransformProcess transformProcess; + private final IPolicy policy; + private Integer maxEpisodeSteps = null; // Default, no max + private String id; + + public Builder(@NonNull Environment environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy policy) { + this.environment = environment; + this.transformProcess = transformProcess; + this.policy = policy; + } + + public Builder maxEpisodeSteps(int maxEpisodeSteps) { + Preconditions.checkArgument(maxEpisodeSteps > 0, "maxEpisodeSteps must be greater than 0, got", maxEpisodeSteps); + this.maxEpisodeSteps = maxEpisodeSteps; + + return this; + } + + public Builder id(String id) { + this.id = id; + return this; + } + + public Agent build() { + return new Agent(this); + } + } +} \ No newline at end 
of file diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java new file mode 100644 index 000000000..898f89241 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java @@ -0,0 +1,23 @@ +package org.deeplearning4j.rl4j.agent.listener; + +import org.deeplearning4j.rl4j.agent.Agent; +import org.deeplearning4j.rl4j.environment.StepResult; +import org.deeplearning4j.rl4j.observation.Observation; + +public interface AgentListener { + enum ListenerResponse { + /** + * Tell the learning process to continue calling the listeners and the training. + */ + CONTINUE, + + /** + * Tell the learning process to stop calling the listeners and terminate the training. + */ + STOP, + } + + AgentListener.ListenerResponse onBeforeEpisode(Agent agent); + AgentListener.ListenerResponse onBeforeStep(Agent agent, Observation observation, ACTION action); + AgentListener.ListenerResponse onAfterStep(Agent agent, StepResult stepResult); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java new file mode 100644 index 000000000..e003934d4 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java @@ -0,0 +1,50 @@ +package org.deeplearning4j.rl4j.agent.listener; + +import org.deeplearning4j.rl4j.agent.Agent; +import org.deeplearning4j.rl4j.environment.StepResult; +import org.deeplearning4j.rl4j.observation.Observation; + +import java.util.ArrayList; +import java.util.List; + +public class AgentListenerList { + protected final List> listeners = new ArrayList<>(); + + /** + * Add a listener at the end of the list + * @param listener The listener to be added + */ + public void add(AgentListener listener) { + 
listeners.add(listener); + } + + public boolean notifyBeforeEpisode(Agent agent) { + for (AgentListener listener : listeners) { + if (listener.onBeforeEpisode(agent) == AgentListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } + + public boolean notifyBeforeStep(Agent agent, Observation observation, ACTION action) { + for (AgentListener listener : listeners) { + if (listener.onBeforeStep(agent, observation, action) == AgentListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } + + public boolean notifyAfterStep(Agent agent, StepResult stepResult) { + for (AgentListener listener : listeners) { + if (listener.onAfterStep(agent, stepResult) == AgentListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/ActionSchema.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/ActionSchema.java new file mode 100644 index 000000000..f6521e734 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/ActionSchema.java @@ -0,0 +1,9 @@ +package org.deeplearning4j.rl4j.environment; + +import lombok.Value; + +@Value +public class ActionSchema { + private ACTION noOp; + //FIXME ACTION randomAction(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java new file mode 100644 index 000000000..95ff7d2b6 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java @@ -0,0 +1,11 @@ +package org.deeplearning4j.rl4j.environment; + +import java.util.Map; + +public interface Environment { + Schema getSchema(); + Map reset(); + StepResult step(ACTION action); + boolean isEpisodeFinished(); + void close(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java 
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java new file mode 100644 index 000000000..5ddea24cd --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java @@ -0,0 +1,8 @@ +package org.deeplearning4j.rl4j.environment; + +import lombok.Value; + +@Value +public class Schema { + private ActionSchema actionSchema; +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java new file mode 100644 index 000000000..b64dd08f5 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java @@ -0,0 +1,12 @@ +package org.deeplearning4j.rl4j.environment; + +import lombok.Value; + +import java.util.Map; + +@Value +public class StepResult { + private Map channelsData; + private double reward; + private boolean terminal; +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java index 0d1f0ae20..db964527e 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java @@ -30,7 +30,7 @@ import org.deeplearning4j.rl4j.space.Encodable; */ public interface ILearning> { - IPolicy getPolicy(); + IPolicy getPolicy(); void train(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java index 864683d79..54be00cfb 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java @@ -221,7 +221,7 @@ public abstract class AsyncThread getPolicy(NN net); + protected abstract IPolicy getPolicy(NN net); 
protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index fcce92a4a..f340e2706 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -97,7 +97,7 @@ public abstract class AsyncThreadDiscrete policy = getPolicy(current); + IPolicy policy = getPolicy(current); Integer action = getMdp().getActionSpace().noOp(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index adf68489e..36f973957 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -65,7 +65,7 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< } @Override - protected Policy getPolicy(IActorCritic net) { + protected Policy getPolicy(IActorCritic net) { return new ACPolicy(net, rnd); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index a4c0b643b..94edac593 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -62,7 +62,7 @@ public abstract class AsyncNStepQLearningDiscrete return 
asyncGlobal.getTarget(); } - public IPolicy getPolicy() { + public IPolicy getPolicy() { return new DQNPolicy(getNeuralNet()); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index 34a2c07a4..ef60c685f 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -64,7 +64,7 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn setUpdateAlgorithm(buildUpdateAlgorithm()); } - public Policy getPolicy(IDQN nn) { + public Policy getPolicy(IDQN nn) { return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(), rnd, conf.getMinEpsilon(), this); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java new file mode 100644 index 000000000..1e1348b4a --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java @@ -0,0 +1,129 @@ +package org.deeplearning4j.rl4j.mdp; + +import lombok.Getter; +import lombok.Setter; +import org.deeplearning4j.rl4j.environment.ActionSchema; +import org.deeplearning4j.rl4j.environment.Environment; +import org.deeplearning4j.rl4j.environment.Schema; +import org.deeplearning4j.rl4j.environment.StepResult; + +import java.util.HashMap; +import java.util.Map; +import java.util.Random; + +public class CartpoleEnvironment implements Environment { + private static final int NUM_ACTIONS = 2; + private static final int ACTION_LEFT = 0; + private static final int ACTION_RIGHT = 1; + + private static final Schema schema = new Schema<>(new 
ActionSchema<>(ACTION_LEFT)); + + public enum KinematicsIntegrators { Euler, SemiImplicitEuler }; + + private static final double gravity = 9.8; + private static final double massCart = 1.0; + private static final double massPole = 0.1; + private static final double totalMass = massPole + massCart; + private static final double length = 0.5; // actually half the pole's length + private static final double polemassLength = massPole * length; + private static final double forceMag = 10.0; + private static final double tau = 0.02; // seconds between state updates + + // Angle at which to fail the episode + private static final double thetaThresholdRadians = 12.0 * 2.0 * Math.PI / 360.0; + private static final double xThreshold = 2.4; + + private final Random rnd; + + @Getter @Setter + private KinematicsIntegrators kinematicsIntegrator = KinematicsIntegrators.Euler; + + @Getter + private boolean episodeFinished = false; + + private double x; + private double xDot; + private double theta; + private double thetaDot; + private Integer stepsBeyondDone; + + public CartpoleEnvironment() { + rnd = new Random(); + } + + public CartpoleEnvironment(int seed) { + rnd = new Random(seed); + } + + @Override + public Schema getSchema() { + return schema; + } + + @Override + public Map reset() { + + x = 0.1 * rnd.nextDouble() - 0.05; + xDot = 0.1 * rnd.nextDouble() - 0.05; + theta = 0.1 * rnd.nextDouble() - 0.05; + thetaDot = 0.1 * rnd.nextDouble() - 0.05; + stepsBeyondDone = null; + episodeFinished = false; + + return new HashMap() {{ + put("data", new double[]{x, xDot, theta, thetaDot}); + }}; + } + + @Override + public StepResult step(Integer action) { + double force = action == ACTION_RIGHT ? 
forceMag : -forceMag; + double cosTheta = Math.cos(theta); + double sinTheta = Math.sin(theta); + double temp = (force + polemassLength * thetaDot * thetaDot * sinTheta) / totalMass; + double thetaAcc = (gravity * sinTheta - cosTheta* temp) / (length * (4.0/3.0 - massPole * cosTheta * cosTheta / totalMass)); + double xAcc = temp - polemassLength * thetaAcc * cosTheta / totalMass; + + switch(kinematicsIntegrator) { + case Euler: + x += tau * xDot; + xDot += tau * xAcc; + theta += tau * thetaDot; + thetaDot += tau * thetaAcc; + break; + + case SemiImplicitEuler: + xDot += tau * xAcc; + x += tau * xDot; + thetaDot += tau * thetaAcc; + theta += tau * thetaDot; + break; + } + + episodeFinished |= x < -xThreshold || x > xThreshold + || theta < -thetaThresholdRadians || theta > thetaThresholdRadians; + + double reward; + if(!episodeFinished) { + reward = 1.0; + } + else if(stepsBeyondDone == null) { + stepsBeyondDone = 0; + reward = 1.0; + } + else { + ++stepsBeyondDone; + reward = 0; + } + + Map channelsData = new HashMap() {{ + put("data", new double[]{x, xDot, theta, thetaDot}); + }}; + return new StepResult(channelsData, reward, episodeFinished); + } + + @Override + public void close() { + // Do nothing + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java index 61ba70825..e01456729 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java @@ -35,7 +35,7 @@ import java.io.IOException; * the softmax output of the actor critic, but objects constructed * with a {@link Random} argument of null return the max only. 
*/ -public class ACPolicy extends Policy { +public class ACPolicy extends Policy { final private IActorCritic actorCritic; Random rnd; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java index cf2b60f41..7508655c3 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java @@ -30,7 +30,7 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp; * Boltzmann exploration is a stochastic policy wrt to the * exponential Q-values as evaluated by the dqn model. */ -public class BoltzmannQ extends Policy { +public class BoltzmannQ extends Policy { final private IDQN dqn; final private Random rnd; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java index c7ef91665..e2982823d 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java @@ -32,8 +32,10 @@ import java.io.IOException; * DQN policy returns the action with the maximum Q-value as evaluated * by the dqn model */ + +// FIXME: Should we rename this "GreedyPolicy"? 
@AllArgsConstructor -public class DQNPolicy extends Policy { +public class DQNPolicy extends Policy { final private IDQN dqn; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java index 2c7695dc7..4801c7b70 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java @@ -41,9 +41,9 @@ import org.nd4j.linalg.api.rng.Random; */ @AllArgsConstructor @Slf4j -public class EpsGreedy> extends Policy { +public class EpsGreedy> extends Policy { - final private Policy policy; + final private Policy policy; final private MDP mdp; final private int updateStart; final private int epsilonNbStep; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java index f87971a89..ffc029835 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java @@ -7,8 +7,14 @@ import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; -public interface IPolicy { - > double play(MDP mdp, IHistoryProcessor hp); - A nextAction(INDArray input); - A nextAction(Observation observation); +public interface IPolicy { + @Deprecated + > double play(MDP mdp, IHistoryProcessor hp); + + @Deprecated + ACTION nextAction(INDArray input); + + ACTION nextAction(Observation observation); + + void reset(); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java index d5fa59766..4885e2c62 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java +++ 
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java @@ -34,22 +34,22 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; * * A Policy responsability is to choose the next action given a state */ -public abstract class Policy implements IPolicy { +public abstract class Policy implements IPolicy { public abstract NeuralNet getNeuralNet(); public abstract A nextAction(Observation obs); - public > double play(MDP mdp) { + public > double play(MDP mdp) { return play(mdp, (IHistoryProcessor)null); } - public > double play(MDP mdp, HistoryProcessor.Configuration conf) { + public > double play(MDP mdp, HistoryProcessor.Configuration conf) { return play(mdp, new HistoryProcessor(conf)); } @Override - public > double play(MDP mdp, IHistoryProcessor hp) { + public > double play(MDP mdp, IHistoryProcessor hp) { resetNetworks(); LegacyMDPWrapper mdpWrapper = new LegacyMDPWrapper(mdp, hp); @@ -84,8 +84,11 @@ public abstract class Policy implements IPolicy { protected void resetNetworks() { getNeuralNet().reset(); } + public void reset() { + resetNetworks(); + } - protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) { + protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) { double reward = 0; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java new file mode 100644 index 000000000..a8beae640 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java @@ -0,0 +1,483 @@ +package org.deeplearning4j.rl4j.agent; + +import org.deeplearning4j.rl4j.agent.listener.AgentListener; +import org.deeplearning4j.rl4j.environment.ActionSchema; +import org.deeplearning4j.rl4j.environment.Environment; +import org.deeplearning4j.rl4j.environment.Schema; +import org.deeplearning4j.rl4j.environment.StepResult; +import 
org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.observation.transform.TransformProcess; +import org.deeplearning4j.rl4j.policy.IPolicy; +import org.junit.Rule; +import org.junit.Test; +import static org.junit.Assert.*; + +import org.mockito.*; +import org.mockito.junit.*; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +public class AgentTest { + + @Mock Environment environmentMock; + @Mock TransformProcess transformProcessMock; + @Mock IPolicy policyMock; + @Mock AgentListener listenerMock; + + @Rule + public MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Test + public void when_buildingWithNullEnvironment_expect_exception() { + try { + Agent.builder(null, null, null).build(); + fail("NullPointerException should have been thrown"); + } catch (NullPointerException exception) { + String expectedMessage = "environment is marked non-null but is null"; + String actualMessage = exception.getMessage(); + + assertTrue(actualMessage.contains(expectedMessage)); + } + } + + @Test + public void when_buildingWithNullTransformProcess_expect_exception() { + try { + Agent.builder(environmentMock, null, null).build(); + fail("NullPointerException should have been thrown"); + } catch (NullPointerException exception) { + String expectedMessage = "transformProcess is marked non-null but is null"; + String actualMessage = exception.getMessage(); + + assertTrue(actualMessage.contains(expectedMessage)); + } + } + + @Test + public void when_buildingWithNullPolicy_expect_exception() { + try { + Agent.builder(environmentMock, transformProcessMock, null).build(); + fail("NullPointerException should have been thrown"); + } catch (NullPointerException exception) { + String expectedMessage = "policy is marked non-null but is null"; + String actualMessage = exception.getMessage(); + + 
assertTrue(actualMessage.contains(expectedMessage)); + } + } + + @Test + public void when_buildingWithInvalidMaxSteps_expect_exception() { + try { + Agent.builder(environmentMock, transformProcessMock, policyMock) + .maxEpisodeSteps(0) + .build(); + fail("IllegalArgumentException should have been thrown"); + } catch (IllegalArgumentException exception) { + String expectedMessage = "maxEpisodeSteps must be greater than 0, got [0]"; + String actualMessage = exception.getMessage(); + + assertTrue(actualMessage.contains(expectedMessage)); + } + } + + @Test + public void when_buildingWithId_expect_idSetInAgent() { + // Arrange + Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock) + .id("TestAgent") + .build(); + + // Assert + assertEquals("TestAgent", sut.getId()); + } + + @Test + public void when_runIsCalled_expect_agentIsReset() { + // Arrange + Map envResetResult = new HashMap<>(); + Schema schema = new Schema(new ActionSchema<>(-1)); + when(environmentMock.reset()).thenReturn(envResetResult); + when(environmentMock.getSchema()).thenReturn(schema); + + when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 }))); + when(policyMock.nextAction(any(Observation.class))).thenReturn(1); + + Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock) + .build(); + + when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), anyInt())).thenReturn(AgentListener.ListenerResponse.STOP); + sut.addListener(listenerMock); + + // Act + sut.run(); + + // Assert + assertEquals(0, sut.getEpisodeStepNumber()); + verify(transformProcessMock).transform(envResetResult, 0, false); + verify(policyMock, times(1)).reset(); + assertEquals(0.0, sut.getReward(), 0.00001); + verify(environmentMock, times(1)).reset(); + } + + @Test + public void when_runIsCalled_expect_onBeforeAndAfterEpisodeCalled() { + // Arrange + Map envResetResult = new HashMap<>(); + Schema 
schema = new Schema(new ActionSchema<>(-1)); + when(environmentMock.reset()).thenReturn(envResetResult); + when(environmentMock.getSchema()).thenReturn(schema); + + when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 }))); + when(environmentMock.isEpisodeFinished()).thenReturn(true); + + Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build(); + Agent spy = Mockito.spy(sut); + + // Act + spy.run(); + + // Assert + verify(spy, times(1)).onBeforeEpisode(); + verify(spy, times(1)).onAfterEpisode(); + } + + @Test + public void when_onBeforeEpisodeReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() { + // Arrange + Map envResetResult = new HashMap<>(); + Schema schema = new Schema(new ActionSchema<>(-1)); + when(environmentMock.reset()).thenReturn(envResetResult); + when(environmentMock.getSchema()).thenReturn(schema); + + when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 }))); + + Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build(); + + when(listenerMock.onBeforeEpisode(any(Agent.class))).thenReturn(AgentListener.ListenerResponse.STOP); + sut.addListener(listenerMock); + + Agent spy = Mockito.spy(sut); + + // Act + spy.run(); + + // Assert + verify(spy, times(1)).onBeforeEpisode(); + verify(spy, never()).performStep(); + verify(spy, never()).onAfterStep(any(StepResult.class)); + verify(spy, never()).onAfterEpisode(); + } + + @Test + public void when_runIsCalledWithoutMaxStep_expect_agentRunUntilEpisodeIsFinished() { + // Arrange + Map envResetResult = new HashMap<>(); + Schema schema = new Schema(new ActionSchema<>(-1)); + when(environmentMock.reset()).thenReturn(envResetResult); + when(environmentMock.getSchema()).thenReturn(schema); + + when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new 
Observation(Nd4j.create(new double[] { 123.0 }))); + + Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock) + .build(); + + final Agent spy = Mockito.spy(sut); + + doAnswer(invocation -> { + ((Agent)invocation.getMock()).incrementEpisodeStepNumber(); + return null; + }).when(spy).performStep(); + when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepNumber() >= 5 ); + + // Act + spy.run(); + + // Assert + verify(spy, times(1)).onBeforeEpisode(); + verify(spy, times(5)).performStep(); + verify(spy, times(1)).onAfterEpisode(); + } + + @Test + public void when_maxStepsIsReachedBeforeEposideEnds_expect_runTerminated() { + // Arrange + Map envResetResult = new HashMap<>(); + Schema schema = new Schema(new ActionSchema<>(-1)); + when(environmentMock.reset()).thenReturn(envResetResult); + when(environmentMock.getSchema()).thenReturn(schema); + + when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 }))); + + Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock) + .maxEpisodeSteps(3) + .build(); + + final Agent spy = Mockito.spy(sut); + + doAnswer(invocation -> { + ((Agent)invocation.getMock()).incrementEpisodeStepNumber(); + return null; + }).when(spy).performStep(); + + // Act + spy.run(); + + // Assert + verify(spy, times(1)).onBeforeEpisode(); + verify(spy, times(3)).performStep(); + verify(spy, times(1)).onAfterEpisode(); + } + + @Test + public void when_initialObservationsAreSkipped_expect_performNoOpAction() { + // Arrange + Map envResetResult = new HashMap<>(); + Schema schema = new Schema(new ActionSchema<>(-1)); + when(environmentMock.reset()).thenReturn(envResetResult); + when(environmentMock.getSchema()).thenReturn(schema); + + when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(Observation.SkippedObservation); + + Agent sut = Agent.builder(environmentMock, 
transformProcessMock, policyMock)
+                .build();
+
+        when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
+        sut.addListener(listenerMock);
+
+        Agent spy = Mockito.spy(sut);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(listenerMock).onBeforeStep(any(), any(), eq(-1)); // -1 is the schema's no-op action, used while observations are skipped
+    }
+
+    @Test
+    public void when_initialObservationsAreSkipped_expect_performNoOpActionAnd() { // NOTE(review): exact duplicate of the previous test, name ends in a dangling "And" — looks like an unfinished copy-paste; consider removing or completing it (left in place here to keep the patch hunk line counts valid)
+        // Arrange
+        Map envResetResult = new HashMap<>();
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(envResetResult);
+        when(environmentMock.getSchema()).thenReturn(schema);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(Observation.SkippedObservation);
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .build();
+
+        when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
+        sut.addListener(listenerMock);
+
+        Agent spy = Mockito.spy(sut);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(listenerMock).onBeforeStep(any(), any(), eq(-1));
+    }
+
+    @Test
+    public void when_observationsIsSkipped_expect_performLastAction() {
+        // Arrange
+        Map envResetResult = new HashMap<>();
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(envResetResult);
+        when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(envResetResult, 0.0, false));
+        when(environmentMock.getSchema()).thenReturn(schema);
+
+        when(policyMock.nextAction(any(Observation.class)))
+                .thenAnswer(invocation -> (int)((Observation)invocation.getArgument(0)).getData().getDouble(0)); // policy echoes the observation value as the action, so actions can be traced back
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .maxEpisodeSteps(3)
+                .build();
+
+        Agent spy = Mockito.spy(sut);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
+                .thenAnswer(invocation
-> {
+                    int stepNumber = (int)invocation.getArgument(1);
+                    return stepNumber % 2 == 1 ? Observation.SkippedObservation // odd steps yield a skipped observation -> agent must repeat its last action
+                            : new Observation(Nd4j.create(new double[] { stepNumber }));
+                });
+
+        sut.addListener(listenerMock);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(policyMock, times(2)).nextAction(any(Observation.class)); // policy only consulted on the two non-skipped observations (steps 0 and 2)
+
+        ArgumentCaptor agentCaptor = ArgumentCaptor.forClass(Agent.class);
+        ArgumentCaptor observationCaptor = ArgumentCaptor.forClass(Observation.class);
+        ArgumentCaptor actionCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(listenerMock, times(3)).onBeforeStep(agentCaptor.capture(), observationCaptor.capture(), actionCaptor.capture());
+        List capturedActions = actionCaptor.getAllValues();
+        assertEquals(0, (int)capturedActions.get(0));
+        assertEquals(0, (int)capturedActions.get(1)); // step 1 was skipped, so action 0 is repeated
+        assertEquals(2, (int)capturedActions.get(2));
+    }
+
+    @Test
+    public void when_onBeforeStepReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
+        // Arrange
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(new HashMap<>());
+        when(environmentMock.getSchema()).thenReturn(schema);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();
+
+        when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
+        sut.addListener(listenerMock);
+
+        Agent spy = Mockito.spy(sut);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(spy, times(1)).onBeforeEpisode();
+        verify(spy, times(1)).onBeforeStep();
+        verify(spy, never()).act(any());
+        verify(spy, never()).onAfterStep(any(StepResult.class));
+        verify(spy, never()).onAfterEpisode();
+    }
+
+    @Test
+    public void when_observationIsNotSkipped_expect_policyActionIsSentToEnvironment() {
+        // Arrange
+        Schema schema = new
Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(new HashMap<>());
+        when(environmentMock.getSchema()).thenReturn(schema);
+        when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 0.0, false));
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        when(policyMock.nextAction(any(Observation.class))).thenReturn(123); // distinctive action value to assert it reaches the environment unchanged
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .maxEpisodeSteps(1)
+                .build();
+
+        // Act
+        sut.run();
+
+        // Assert
+        verify(environmentMock, times(1)).step(123);
+    }
+
+    @Test
+    public void when_stepResultIsReceived_expect_observationAndRewardUpdated() {
+        // Arrange
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(new HashMap<>());
+        when(environmentMock.getSchema()).thenReturn(schema);
+        when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 234.0, false));
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .maxEpisodeSteps(1)
+                .build();
+
+        // Act
+        sut.run();
+
+        // Assert
+        assertEquals(123.0, sut.getObservation().getData().getDouble(0), 0.00001);
+        assertEquals(234.0, sut.getReward(), 0.00001);
+    }
+
+    @Test
+    public void when_stepIsDone_expect_onAfterStepAndWithStepResult() {
+        // Arrange
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(new HashMap<>());
+        when(environmentMock.getSchema()).thenReturn(schema);
+        StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
+        when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
+
+        
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .maxEpisodeSteps(1)
+                .build();
+        Agent spy = Mockito.spy(sut);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(spy).onAfterStep(stepResult); // the exact StepResult instance returned by the environment must be forwarded
+    }
+
+    @Test
+    public void when_onAfterStepReturnsStop_expect_onAfterEpisodeNotCalled() {
+        // Arrange
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(new HashMap<>());
+        when(environmentMock.getSchema()).thenReturn(schema);
+        StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
+        when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .maxEpisodeSteps(1)
+                .build();
+        when(listenerMock.onAfterStep(any(Agent.class), any(StepResult.class))).thenReturn(AgentListener.ListenerResponse.STOP);
+        sut.addListener(listenerMock);
+
+        Agent spy = Mockito.spy(sut);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(spy, never()).onAfterEpisode();
+    }
+
+    @Test
+    public void when_runIsCalled_expect_onAfterEpisodeIsCalled() {
+        // Arrange
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(new HashMap<>());
+        when(environmentMock.getSchema()).thenReturn(schema);
+        StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
+        when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new
double[] { 123.0 }))); + + when(policyMock.nextAction(any(Observation.class))).thenReturn(123); + + Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock) + .maxEpisodeSteps(1) + .build(); + + Agent spy = Mockito.spy(sut); + + // Act + spy.run(); + + // Assert + verify(spy, times(1)).onAfterEpisode(); + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java index 5f2a8ab31..9499da99e 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java @@ -62,7 +62,7 @@ public class AsyncThreadDiscreteTest { IAsyncGlobal mockAsyncGlobal; @Mock - Policy mockGlobalCurrentPolicy; + Policy mockGlobalCurrentPolicy; @Mock NeuralNet mockGlobalTargetNetwork; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index 7db92a599..403d3c91e 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -30,13 +30,8 @@ import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.support.MockDQN; -import org.deeplearning4j.rl4j.support.MockEncodable; -import org.deeplearning4j.rl4j.support.MockHistoryProcessor; -import org.deeplearning4j.rl4j.support.MockMDP; -import org.deeplearning4j.rl4j.support.MockNeuralNet; -import org.deeplearning4j.rl4j.support.MockObservationSpace; -import org.deeplearning4j.rl4j.support.MockRandom; +import 
org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.junit.Test; import org.nd4j.linalg.activations.Activation; @@ -227,7 +222,7 @@ public class PolicyTest { assertEquals(0, dqn.outputParams.size()); } - public static class MockRefacPolicy extends Policy { + public static class MockRefacPolicy extends Policy { private NeuralNet neuralNet; private final int[] shape; @@ -257,7 +252,7 @@ public class PolicyTest { } @Override - protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) { + protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) { mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength)); return super.refacInitMdp(mdpWrapper, hp); } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java index 4c4f100e9..2b9fd491b 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java @@ -5,18 +5,19 @@ import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.ArrayList; import java.util.List; -public class MockPolicy implements IPolicy { +public class MockPolicy implements IPolicy { public int playCallCount = 0; public List actionInputs = new ArrayList(); @Override - public > double play(MDP mdp, IHistoryProcessor hp) { + public > double play(MDP mdp, IHistoryProcessor hp) { ++playCallCount; return 0; } @@ -31,4 +32,9 @@ public class MockPolicy implements IPolicy { public Integer nextAction(Observation 
observation) { return nextAction(observation.getData()); } + + @Override + public void reset() { + + } }