Use latest black #116

Merged (2 commits) on Jan 6, 2021
Changes from all commits
3 changes: 1 addition & 2 deletions .pfnci/lint.sh
@@ -2,8 +2,7 @@

 set -eux

-# Use latest black to apply https://github.com/psf/black/issues/1288
-pip3 install git+git://github.com/psf/black.git@88d12f88a97e5e4c8fd0d245df0a311e932fd1e1 flake8 mypy isort
+pip3 install black flake8 mypy isort

 black --diff --check pfrl tests examples
 isort --diff --check pfrl tests examples
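Note (not part of the original diff): the pin in lint.sh was presumably only needed to pick up the behavior requested in https://github.com/psf/black/issues/1288 before it shipped in a release. Every change below follows the same pattern, which matches black's magic trailing comma handling: a call or collection that already ends with a trailing comma is exploded onto one element per line. A minimal sketch of that behavior, assuming a post-20.8b0 black is installed (the snippet is illustrative, not code from this repository):

import black

# A single-line call that already carries a trailing comma, as in the old code.
src = "x = f(a, b,)\n"

# The released black keeps the trailing comma and explodes the call,
# producing the multi-line form seen throughout this PR.
print(black.format_str(src, mode=black.Mode()))
# Expected output:
# x = f(
#     a,
#     b,
# )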
5 changes: 4 additions & 1 deletion examples/atari/reproduction/a3c/train_a3c.py
@@ -120,7 +120,10 @@ def make_env(process_idx, test):
         nn.Linear(2592, 256),
         nn.ReLU(),
         pfrl.nn.Branched(
-            nn.Sequential(nn.Linear(256, n_actions), SoftmaxCategoricalHead(),),
+            nn.Sequential(
+                nn.Linear(256, n_actions),
+                SoftmaxCategoricalHead(),
+            ),
             nn.Linear(256, 1),
         ),
     )
16 changes: 13 additions & 3 deletions examples/atari/reproduction/iqn/train_iqn.py
@@ -124,8 +124,15 @@ def make_env(test):
             nn.ReLU(),
             nn.Flatten(),
         ),
-        phi=nn.Sequential(pfrl.agents.iqn.CosineBasisLinear(64, 3136), nn.ReLU(),),
-        f=nn.Sequential(nn.Linear(3136, 512), nn.ReLU(), nn.Linear(512, n_actions),),
+        phi=nn.Sequential(
+            pfrl.agents.iqn.CosineBasisLinear(64, 3136),
+            nn.ReLU(),
+        ),
+        f=nn.Sequential(
+            nn.Linear(3136, 512),
+            nn.ReLU(),
+            nn.Linear(512, n_actions),
+        ),
     )

     # Use the same hyper parameters as https://arxiv.org/abs/1710.10044
@@ -175,7 +182,10 @@ def phi(x):

     if args.demo:
         eval_stats = experiments.eval_performance(
-            env=eval_env, agent=agent, n_steps=args.eval_n_steps, n_episodes=None,
+            env=eval_env,
+            agent=agent,
+            n_steps=args.eval_n_steps,
+            n_episodes=None,
         )
         print(
             "n_steps: {} mean: {} median: {} stdev {}".format(
7 changes: 6 additions & 1 deletion examples/atari/reproduction/rainbow/train_rainbow.py
@@ -110,7 +110,12 @@ def make_env(test):
     n_atoms = 51
     v_max = 10
     v_min = -10
-    q_func = DistributionalDuelingDQN(n_actions, n_atoms, v_min, v_max,)
+    q_func = DistributionalDuelingDQN(
+        n_actions,
+        n_atoms,
+        v_min,
+        v_max,
+    )

     # Noisy nets
     pnn.to_factorized_noisy(q_func, sigma_scale=args.noisy_net_sigma)
10 changes: 8 additions & 2 deletions examples/atari/train_a2c_ale.py
@@ -136,12 +136,18 @@ def make_batch_env(test):
         nn.Linear(2592, 256),
         nn.ReLU(),
         pfrl.nn.Branched(
-            nn.Sequential(nn.Linear(256, n_actions), SoftmaxCategoricalHead(),),
+            nn.Sequential(
+                nn.Linear(256, n_actions),
+                SoftmaxCategoricalHead(),
+            ),
             nn.Linear(256, 1),
         ),
     )
     optimizer = pfrl.optimizers.RMSpropEpsInsideSqrt(
-        model.parameters(), lr=args.lr, eps=args.rmsprop_epsilon, alpha=args.alpha,
+        model.parameters(),
+        lr=args.lr,
+        eps=args.rmsprop_epsilon,
+        alpha=args.alpha,
     )

     agent = a2c.A2C(
10 changes: 8 additions & 2 deletions examples/atari/train_acer_ale.py
@@ -105,8 +105,14 @@ def main():
     )

     head = acer.ACERDiscreteActionHead(
-        pi=nn.Sequential(nn.Linear(256, n_actions), SoftmaxCategoricalHead(),),
-        q=nn.Sequential(nn.Linear(256, n_actions), DiscreteActionValueHead(),),
+        pi=nn.Sequential(
+            nn.Linear(256, n_actions),
+            SoftmaxCategoricalHead(),
+        ),
+        q=nn.Sequential(
+            nn.Linear(256, n_actions),
+            DiscreteActionValueHead(),
+        ),
     )

     if args.use_lstm:
5 changes: 4 additions & 1 deletion examples/atari/train_drqn_ale.py
@@ -275,7 +275,10 @@ def phi(x):

     if args.demo:
         eval_stats = experiments.eval_performance(
-            env=eval_env, agent=agent, n_steps=None, n_episodes=args.demo_n_episodes,
+            env=eval_env,
+            agent=agent,
+            n_steps=None,
+            n_episodes=args.demo_n_episodes,
         )
         print(
             "n_runs: {} mean: {} median: {} stdev {}".format(
5 changes: 4 additions & 1 deletion examples/grasping/train_dqn_batch_grasping.py
@@ -53,7 +53,10 @@ def __init__(self, env, max_steps):
         self._max_steps = max_steps
         self._elapsed_steps = 0
         self.observation_space = gym.spaces.Tuple(
-            (env.observation_space, gym.spaces.Discrete(self._max_steps + 1),)
+            (
+                env.observation_space,
+                gym.spaces.Discrete(self._max_steps + 1),
+            )
         )

     def reset(self):
26 changes: 20 additions & 6 deletions examples/optuna/optuna_dqn_obs1d.py
@@ -222,7 +222,9 @@ def suggest(trial, steps):
         max(1e3, hyperparams["decay_steps"] // 2), rbuf_capacity
     )
     hyperparams["replay_start_size"] = trial.suggest_int(
-        "replay_start_size", min_replay_start_size, max_replay_start_size,
+        "replay_start_size",
+        min_replay_start_size,
+        max_replay_start_size,
     )
     # target_update_interval should be a multiple of update_interval
     hyperparams["update_interval"] = trial.suggest_int("update_interval", 1, 8)
@@ -239,7 +241,10 @@ def main():

     # training parameters
     parser.add_argument(
-        "--env", type=str, default="LunarLander-v2", help="OpenAI Gym Environment ID.",
+        "--env",
+        type=str,
+        default="LunarLander-v2",
+        help="OpenAI Gym Environment ID.",
     )
     parser.add_argument(
         "--outdir",
@@ -251,7 +256,10 @@
         ),
     )
     parser.add_argument(
-        "--seed", type=int, default=0, help="Random seed for randomizer.",
+        "--seed",
+        type=int,
+        default=0,
+        help="Random seed for randomizer.",
     )
     parser.add_argument(
         "--monitor",
@@ -289,7 +297,10 @@
         help="Frequency (in timesteps) of evaluation phase.",
     )
     parser.add_argument(
-        "--batch-size", type=int, default=64, help="Training batch size.",
+        "--batch-size",
+        type=int,
+        default=64,
+        help="Training batch size.",
     )

     # Optuna related args
@@ -361,7 +372,9 @@
         help="Setting percentile == 50.0 is equivalent to the MedianPruner.",
     )
     parser.add_argument(
-        "--n-startup-trials", type=int, default=5,
+        "--n-startup-trials",
+        type=int,
+        default=5,
     )
     parser.add_argument(
         "--n-warmup-steps",
@@ -417,7 +430,8 @@ def objective(trial):
         pruner = optuna.pruners.NopPruner()
     elif args.optuna_pruner == "ThresholdPruner":
         pruner = optuna.pruners.ThresholdPruner(
-            lower=args.lower, n_warmup_steps=args.n_warmup_steps,
+            lower=args.lower,
+            n_warmup_steps=args.n_warmup_steps,
         )
     elif args.optuna_pruner == "PercentilePruner":
         pruner = optuna.pruners.PercentilePruner(
5 changes: 4 additions & 1 deletion examples/slimevolley/train_rainbow.py
@@ -210,7 +210,10 @@ def phi(x):

     if args.demo:
         eval_stats = experiments.eval_performance(
-            env=eval_env, agent=agent, n_steps=None, n_episodes=args.eval_n_episodes,
+            env=eval_env,
+            agent=agent,
+            n_steps=None,
+            n_episodes=args.eval_n_episodes,
         )
         print(
             "n_episodes: {} mean: {} median: {} stdev {}".format(
23 changes: 16 additions & 7 deletions pfrl/action_value.py
@@ -174,7 +174,9 @@ def params(self):

     def __getitem__(self, i):
         return DistributionalDiscreteActionValue(
-            self.q_dist[i], self.z_values, q_values_formatter=self.q_values_formatter,
+            self.q_dist[i],
+            self.z_values,
+            q_values_formatter=self.q_values_formatter,
         )


@@ -209,9 +211,11 @@ def evaluate_actions_as_quantiles(self, actions):
         ]

     def __repr__(self):
-        return "QuantileDiscreteActionValue greedy_actions:{} q_values:{}".format(  # NOQA
-            self.greedy_actions.detach().cpu().numpy(),
-            self.q_values_formatter(self.q_values.detach().cpu().numpy()),
+        return (
+            "QuantileDiscreteActionValue greedy_actions:{} q_values:{}".format(  # NOQA
+                self.greedy_actions.detach().cpu().numpy(),
+                self.q_values_formatter(self.q_values.detach().cpu().numpy()),
+            )
         )

     @property
@@ -220,7 +224,8 @@ def params(self):

     def __getitem__(self, i):
         return QuantileDiscreteActionValue(
-            quantiles=self.quantiles[i], q_values_formatter=self.q_values_formatter,
+            quantiles=self.quantiles[i],
+            q_values_formatter=self.q_values_formatter,
         )


@@ -276,7 +281,9 @@ def greedy_actions(self):
     @lazy_property
     def max(self):
         if self.min_action is None and self.max_action is None:
-            return self.v.reshape(self.batch_size,)
+            return self.v.reshape(
+                self.batch_size,
+            )
         else:
             return self.evaluate_actions(self.greedy_actions)

@@ -288,7 +295,9 @@ def evaluate_actions(self, actions):
                 torch.matmul(u_minus_mu[:, None, :], self.mat), u_minus_mu[:, :, None]
             )[:, 0, 0]
         )
-        return a + self.v.reshape(self.batch_size,)
+        return a + self.v.reshape(
+            self.batch_size,
+        )

     def compute_advantage(self, actions):
         return self.evaluate_actions(actions) - self.max
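Side note (not from the PR): the reshape calls in action_value.py change only in layout. A quick sanity check that the exploded form is equivalent, assuming PyTorch is installed:

import torch

batch_size = 4
v = torch.zeros(batch_size, 1)

# The single-line call with a trailing comma and the exploded call pass the
# same shape argument, so the results are identical.
assert torch.equal(
    v.reshape(batch_size,),
    v.reshape(
        batch_size,
    ),
)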
4 changes: 3 additions & 1 deletion pfrl/agents/a3c.py
@@ -168,7 +168,9 @@ def update(self, statevar):
         )
         if self.recurrent:
             (batch_distrib, batch_v), _ = pack_and_forward(
-                self.model, [batch_obs], self.past_recurrent_state[self.t_start],
+                self.model,
+                [batch_obs],
+                self.past_recurrent_state[self.t_start],
             )
         else:
             batch_distrib, batch_v = self.model(batch_obs)
23 changes: 17 additions & 6 deletions pfrl/agents/acer.py
@@ -188,13 +188,17 @@ def deepcopy_distribution(distrib):
     """
     if isinstance(distrib, torch.distributions.Independent):
         return torch.distributions.Independent(
-            deepcopy_distribution(distrib.base_dist), distrib.reinterpreted_batch_ndims,
+            deepcopy_distribution(distrib.base_dist),
+            distrib.reinterpreted_batch_ndims,
         )
     elif isinstance(distrib, torch.distributions.Categorical):
-        return torch.distributions.Categorical(logits=distrib.logits.clone().detach(),)
+        return torch.distributions.Categorical(
+            logits=distrib.logits.clone().detach(),
+        )
    elif isinstance(distrib, torch.distributions.Normal):
         return torch.distributions.Normal(
-            loc=distrib.loc.clone().detach(), scale=distrib.scale.clone().detach(),
+            loc=distrib.loc.clone().detach(),
+            scale=distrib.scale.clone().detach(),
         )
     else:
         raise NotImplementedError("{} is not supported by ACER".format(type(distrib)))
@@ -624,7 +628,9 @@ def update_from_replay(self):
                     (avg_action_distrib, _, _),
                     shared_recurrent_state,
                 ) = one_step_forward(
-                    self.shared_average_model, bs, shared_recurrent_state,
+                    self.shared_average_model,
+                    bs,
+                    shared_recurrent_state,
                 )
             else:
                 avg_action_distrib, _, _ = self.shared_average_model(bs)
@@ -731,7 +737,9 @@ def _act_train(self, obs):
                 (avg_action_distrib, _, _),
                 self.shared_recurrent_states,
             ) = one_step_forward(
-                self.shared_average_model, statevar, self.shared_recurrent_states,
+                self.shared_average_model,
+                statevar,
+                self.shared_recurrent_states,
             )
         else:
             avg_action_distrib, _, _ = self.shared_average_model(statevar)
@@ -789,7 +797,10 @@ def _observe_train(self, state, reward, done, reset):
         self.past_rewards[self.t - 1] = reward
         if self.process_idx == 0:
             self.logger.debug(
-                "t:%s r:%s a:%s", self.t, reward, self.last_action,
+                "t:%s r:%s a:%s",
+                self.t,
+                reward,
+                self.last_action,
             )

         if self.t - self.t_start == self.t_max or done or reset:
12 changes: 9 additions & 3 deletions pfrl/agents/al.py
@@ -27,7 +27,9 @@ def _compute_y_and_t(self, exp_batch):

         if self.recurrent:
             qout, _ = pack_and_forward(
-                self.model, batch_state, exp_batch["recurrent_state"],
+                self.model,
+                batch_state,
+                exp_batch["recurrent_state"],
             )
         else:
             qout = self.model(batch_state)
@@ -42,7 +44,9 @@ def _compute_y_and_t(self, exp_batch):
         with torch.no_grad():
             if self.recurrent:
                 target_qout, _ = pack_and_forward(
-                    self.target_model, batch_state, exp_batch["recurrent_state"],
+                    self.target_model,
+                    batch_state,
+                    exp_batch["recurrent_state"],
                 )
                 target_next_qout, _ = pack_and_forward(
                     self.target_model,
@@ -53,7 +57,9 @@
                 target_qout = self.target_model(batch_state)
                 target_next_qout = self.target_model(batch_next_state)

-            next_q_max = target_next_qout.max.reshape(batch_size,)
+            next_q_max = target_next_qout.max.reshape(
+                batch_size,
+            )

             batch_rewards = exp_batch["reward"]
             batch_terminal = exp_batch["is_state_terminal"]
4 changes: 3 additions & 1 deletion pfrl/agents/categorical_double_dqn.py
@@ -24,7 +24,9 @@ def _compute_target_values(self, exp_batch):
                 exp_batch["next_recurrent_state"],
             )
             next_qout, _ = pack_and_forward(
-                self.model, batch_next_state, exp_batch["next_recurrent_state"],
+                self.model,
+                batch_next_state,
+                exp_batch["next_recurrent_state"],
             )
         else:
             target_next_qout = self.target_model(batch_next_state)
4 changes: 3 additions & 1 deletion pfrl/agents/categorical_dqn.py
@@ -117,7 +117,9 @@ def _compute_target_values(self, exp_batch):
         batch_next_state = exp_batch["next_state"]
         if self.recurrent:
             target_next_qout, _ = pack_and_forward(
-                self.target_model, batch_next_state, exp_batch["next_recurrent_state"],
+                self.target_model,
+                batch_next_state,
+                exp_batch["next_recurrent_state"],
             )
         else:
             target_next_qout = self.target_model(batch_next_state)