
Merge pull request #116 from muupan/update-black
Use latest black
ummavi committed Jan 6, 2021
2 parents 1745d84 + 5e08a09 commit 03a17d6
Showing 51 changed files with 541 additions and 147 deletions.
3 changes: 1 addition & 2 deletions .pfnci/lint.sh
@@ -2,8 +2,7 @@

set -eux

# Use latest black to apply https://github.com/psf/black/issues/1288
pip3 install git+git://github.com/psf/black.git@88d12f88a97e5e4c8fd0d245df0a311e932fd1e1 flake8 mypy isort
pip3 install black flake8 mypy isort

black --diff --check pfrl tests examples
isort --diff --check pfrl tests examples
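Note: the only functional change in this commit is the one above in .pfnci/lint.sh. The install pinned to a specific upstream black commit (needed at the time to pick up the trailing-comma handling discussed in psf/black#1288) is replaced by a plain "pip3 install black", since a released version with that behavior is now available from PyPI. Every other file below is mechanical reformatting produced by the newer black: a call that already ends with a trailing comma is kept exploded, one argument per line. A minimal sketch of that behavior, assuming black >= 20.8b0 is installed; build_head and its arguments are placeholder names inside the string being formatted, not code from this repository:

import black

# Source that already contains a "magic" trailing comma inside the call.
src = "head = build_head(linear_layer, softmax_head,)\n"

# Newer black keeps the call exploded, one argument per line, because of that comma.
print(black.format_str(src, mode=black.Mode()))
# Expected output:
# head = build_head(
#     linear_layer,
#     softmax_head,
# )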
5 changes: 4 additions & 1 deletion examples/atari/reproduction/a3c/train_a3c.py
@@ -120,7 +120,10 @@ def make_env(process_idx, test):
nn.Linear(2592, 256),
nn.ReLU(),
pfrl.nn.Branched(
nn.Sequential(nn.Linear(256, n_actions), SoftmaxCategoricalHead(),),
nn.Sequential(
nn.Linear(256, n_actions),
SoftmaxCategoricalHead(),
),
nn.Linear(256, 1),
),
)
16 changes: 13 additions & 3 deletions examples/atari/reproduction/iqn/train_iqn.py
@@ -124,8 +124,15 @@ def make_env(test):
nn.ReLU(),
nn.Flatten(),
),
phi=nn.Sequential(pfrl.agents.iqn.CosineBasisLinear(64, 3136), nn.ReLU(),),
f=nn.Sequential(nn.Linear(3136, 512), nn.ReLU(), nn.Linear(512, n_actions),),
phi=nn.Sequential(
pfrl.agents.iqn.CosineBasisLinear(64, 3136),
nn.ReLU(),
),
f=nn.Sequential(
nn.Linear(3136, 512),
nn.ReLU(),
nn.Linear(512, n_actions),
),
)

# Use the same hyper parameters as https://arxiv.org/abs/1710.10044
@@ -175,7 +182,10 @@ def phi(x):

if args.demo:
eval_stats = experiments.eval_performance(
env=eval_env, agent=agent, n_steps=args.eval_n_steps, n_episodes=None,
env=eval_env,
agent=agent,
n_steps=args.eval_n_steps,
n_episodes=None,
)
print(
"n_steps: {} mean: {} median: {} stdev {}".format(
7 changes: 6 additions & 1 deletion examples/atari/reproduction/rainbow/train_rainbow.py
@@ -110,7 +110,12 @@ def make_env(test):
n_atoms = 51
v_max = 10
v_min = -10
q_func = DistributionalDuelingDQN(n_actions, n_atoms, v_min, v_max,)
q_func = DistributionalDuelingDQN(
n_actions,
n_atoms,
v_min,
v_max,
)

# Noisy nets
pnn.to_factorized_noisy(q_func, sigma_scale=args.noisy_net_sigma)
10 changes: 8 additions & 2 deletions examples/atari/train_a2c_ale.py
@@ -136,12 +136,18 @@ def make_batch_env(test):
nn.Linear(2592, 256),
nn.ReLU(),
pfrl.nn.Branched(
nn.Sequential(nn.Linear(256, n_actions), SoftmaxCategoricalHead(),),
nn.Sequential(
nn.Linear(256, n_actions),
SoftmaxCategoricalHead(),
),
nn.Linear(256, 1),
),
)
optimizer = pfrl.optimizers.RMSpropEpsInsideSqrt(
model.parameters(), lr=args.lr, eps=args.rmsprop_epsilon, alpha=args.alpha,
model.parameters(),
lr=args.lr,
eps=args.rmsprop_epsilon,
alpha=args.alpha,
)

agent = a2c.A2C(
10 changes: 8 additions & 2 deletions examples/atari/train_acer_ale.py
@@ -105,8 +105,14 @@ def main():
)

head = acer.ACERDiscreteActionHead(
pi=nn.Sequential(nn.Linear(256, n_actions), SoftmaxCategoricalHead(),),
q=nn.Sequential(nn.Linear(256, n_actions), DiscreteActionValueHead(),),
pi=nn.Sequential(
nn.Linear(256, n_actions),
SoftmaxCategoricalHead(),
),
q=nn.Sequential(
nn.Linear(256, n_actions),
DiscreteActionValueHead(),
),
)

if args.use_lstm:
5 changes: 4 additions & 1 deletion examples/atari/train_drqn_ale.py
@@ -275,7 +275,10 @@ def phi(x):

if args.demo:
eval_stats = experiments.eval_performance(
env=eval_env, agent=agent, n_steps=None, n_episodes=args.demo_n_episodes,
env=eval_env,
agent=agent,
n_steps=None,
n_episodes=args.demo_n_episodes,
)
print(
"n_runs: {} mean: {} median: {} stdev {}".format(
5 changes: 4 additions & 1 deletion examples/grasping/train_dqn_batch_grasping.py
@@ -53,7 +53,10 @@ def __init__(self, env, max_steps):
self._max_steps = max_steps
self._elapsed_steps = 0
self.observation_space = gym.spaces.Tuple(
(env.observation_space, gym.spaces.Discrete(self._max_steps + 1),)
(
env.observation_space,
gym.spaces.Discrete(self._max_steps + 1),
)
)

def reset(self):
26 changes: 20 additions & 6 deletions examples/optuna/optuna_dqn_obs1d.py
@@ -222,7 +222,9 @@ def suggest(trial, steps):
max(1e3, hyperparams["decay_steps"] // 2), rbuf_capacity
)
hyperparams["replay_start_size"] = trial.suggest_int(
"replay_start_size", min_replay_start_size, max_replay_start_size,
"replay_start_size",
min_replay_start_size,
max_replay_start_size,
)
# target_update_interval should be a multiple of update_interval
hyperparams["update_interval"] = trial.suggest_int("update_interval", 1, 8)
@@ -239,7 +241,10 @@ def main():

# training parameters
parser.add_argument(
"--env", type=str, default="LunarLander-v2", help="OpenAI Gym Environment ID.",
"--env",
type=str,
default="LunarLander-v2",
help="OpenAI Gym Environment ID.",
)
parser.add_argument(
"--outdir",
@@ -251,7 +256,10 @@
),
)
parser.add_argument(
"--seed", type=int, default=0, help="Random seed for randomizer.",
"--seed",
type=int,
default=0,
help="Random seed for randomizer.",
)
parser.add_argument(
"--monitor",
@@ -289,7 +297,10 @@ def main():
help="Frequency (in timesteps) of evaluation phase.",
)
parser.add_argument(
"--batch-size", type=int, default=64, help="Training batch size.",
"--batch-size",
type=int,
default=64,
help="Training batch size.",
)

# Optuna related args
@@ -361,7 +372,9 @@ def main():
help="Setting percentile == 50.0 is equivalent to the MedianPruner.",
)
parser.add_argument(
"--n-startup-trials", type=int, default=5,
"--n-startup-trials",
type=int,
default=5,
)
parser.add_argument(
"--n-warmup-steps",
@@ -417,7 +430,8 @@ def objective(trial):
pruner = optuna.pruners.NopPruner()
elif args.optuna_pruner == "ThresholdPruner":
pruner = optuna.pruners.ThresholdPruner(
lower=args.lower, n_warmup_steps=args.n_warmup_steps,
lower=args.lower,
n_warmup_steps=args.n_warmup_steps,
)
elif args.optuna_pruner == "PercentilePruner":
pruner = optuna.pruners.PercentilePruner(
5 changes: 4 additions & 1 deletion examples/slimevolley/train_rainbow.py
@@ -210,7 +210,10 @@ def phi(x):

if args.demo:
eval_stats = experiments.eval_performance(
env=eval_env, agent=agent, n_steps=None, n_episodes=args.eval_n_episodes,
env=eval_env,
agent=agent,
n_steps=None,
n_episodes=args.eval_n_episodes,
)
print(
"n_episodes: {} mean: {} median: {} stdev {}".format(
23 changes: 16 additions & 7 deletions pfrl/action_value.py
@@ -174,7 +174,9 @@ def params(self):

def __getitem__(self, i):
return DistributionalDiscreteActionValue(
self.q_dist[i], self.z_values, q_values_formatter=self.q_values_formatter,
self.q_dist[i],
self.z_values,
q_values_formatter=self.q_values_formatter,
)


@@ -209,9 +211,11 @@ def evaluate_actions_as_quantiles(self, actions):
]

def __repr__(self):
return "QuantileDiscreteActionValue greedy_actions:{} q_values:{}".format( # NOQA
self.greedy_actions.detach().cpu().numpy(),
self.q_values_formatter(self.q_values.detach().cpu().numpy()),
return (
"QuantileDiscreteActionValue greedy_actions:{} q_values:{}".format( # NOQA
self.greedy_actions.detach().cpu().numpy(),
self.q_values_formatter(self.q_values.detach().cpu().numpy()),
)
)

@property
@@ -220,7 +224,8 @@ def params(self):

def __getitem__(self, i):
return QuantileDiscreteActionValue(
quantiles=self.quantiles[i], q_values_formatter=self.q_values_formatter,
quantiles=self.quantiles[i],
q_values_formatter=self.q_values_formatter,
)


@@ -276,7 +281,9 @@ def greedy_actions(self):
@lazy_property
def max(self):
if self.min_action is None and self.max_action is None:
return self.v.reshape(self.batch_size,)
return self.v.reshape(
self.batch_size,
)
else:
return self.evaluate_actions(self.greedy_actions)

@@ -288,7 +295,9 @@ def evaluate_actions(self, actions):
torch.matmul(u_minus_mu[:, None, :], self.mat), u_minus_mu[:, :, None]
)[:, 0, 0]
)
return a + self.v.reshape(self.batch_size,)
return a + self.v.reshape(
self.batch_size,
)

def compute_advantage(self, actions):
return self.evaluate_actions(actions) - self.max
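Side note on the reshape calls above: black splits self.v.reshape(self.batch_size,) across three lines only because of the trailing comma; the comma has no effect on what reshape does. A minimal sketch (not code from this repository) showing the two spellings are equivalent:

import torch

v = torch.zeros(6)
# A trailing comma inside the call is ignored by Python; both calls return shape (6,).
assert v.reshape(6,).shape == v.reshape(6).shape == torch.Size([6])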
4 changes: 3 additions & 1 deletion pfrl/agents/a3c.py
@@ -168,7 +168,9 @@ def update(self, statevar):
)
if self.recurrent:
(batch_distrib, batch_v), _ = pack_and_forward(
self.model, [batch_obs], self.past_recurrent_state[self.t_start],
self.model,
[batch_obs],
self.past_recurrent_state[self.t_start],
)
else:
batch_distrib, batch_v = self.model(batch_obs)
23 changes: 17 additions & 6 deletions pfrl/agents/acer.py
@@ -188,13 +188,17 @@ def deepcopy_distribution(distrib):
"""
if isinstance(distrib, torch.distributions.Independent):
return torch.distributions.Independent(
deepcopy_distribution(distrib.base_dist), distrib.reinterpreted_batch_ndims,
deepcopy_distribution(distrib.base_dist),
distrib.reinterpreted_batch_ndims,
)
elif isinstance(distrib, torch.distributions.Categorical):
return torch.distributions.Categorical(logits=distrib.logits.clone().detach(),)
return torch.distributions.Categorical(
logits=distrib.logits.clone().detach(),
)
elif isinstance(distrib, torch.distributions.Normal):
return torch.distributions.Normal(
loc=distrib.loc.clone().detach(), scale=distrib.scale.clone().detach(),
loc=distrib.loc.clone().detach(),
scale=distrib.scale.clone().detach(),
)
else:
raise NotImplementedError("{} is not supported by ACER".format(type(distrib)))
@@ -624,7 +628,9 @@ def update_from_replay(self):
(avg_action_distrib, _, _),
shared_recurrent_state,
) = one_step_forward(
self.shared_average_model, bs, shared_recurrent_state,
self.shared_average_model,
bs,
shared_recurrent_state,
)
else:
avg_action_distrib, _, _ = self.shared_average_model(bs)
@@ -731,7 +737,9 @@ def _act_train(self, obs):
(avg_action_distrib, _, _),
self.shared_recurrent_states,
) = one_step_forward(
self.shared_average_model, statevar, self.shared_recurrent_states,
self.shared_average_model,
statevar,
self.shared_recurrent_states,
)
else:
avg_action_distrib, _, _ = self.shared_average_model(statevar)
@@ -789,7 +797,10 @@ def _observe_train(self, state, reward, done, reset):
self.past_rewards[self.t - 1] = reward
if self.process_idx == 0:
self.logger.debug(
"t:%s r:%s a:%s", self.t, reward, self.last_action,
"t:%s r:%s a:%s",
self.t,
reward,
self.last_action,
)

if self.t - self.t_start == self.t_max or done or reset:
12 changes: 9 additions & 3 deletions pfrl/agents/al.py
@@ -27,7 +27,9 @@ def _compute_y_and_t(self, exp_batch):

if self.recurrent:
qout, _ = pack_and_forward(
self.model, batch_state, exp_batch["recurrent_state"],
self.model,
batch_state,
exp_batch["recurrent_state"],
)
else:
qout = self.model(batch_state)
@@ -42,7 +44,9 @@ def _compute_y_and_t(self, exp_batch):
with torch.no_grad():
if self.recurrent:
target_qout, _ = pack_and_forward(
self.target_model, batch_state, exp_batch["recurrent_state"],
self.target_model,
batch_state,
exp_batch["recurrent_state"],
)
target_next_qout, _ = pack_and_forward(
self.target_model,
@@ -53,7 +57,9 @@ def _compute_y_and_t(self, exp_batch):
target_qout = self.target_model(batch_state)
target_next_qout = self.target_model(batch_next_state)

next_q_max = target_next_qout.max.reshape(batch_size,)
next_q_max = target_next_qout.max.reshape(
batch_size,
)

batch_rewards = exp_batch["reward"]
batch_terminal = exp_batch["is_state_terminal"]
4 changes: 3 additions & 1 deletion pfrl/agents/categorical_double_dqn.py
@@ -24,7 +24,9 @@ def _compute_target_values(self, exp_batch):
exp_batch["next_recurrent_state"],
)
next_qout, _ = pack_and_forward(
self.model, batch_next_state, exp_batch["next_recurrent_state"],
self.model,
batch_next_state,
exp_batch["next_recurrent_state"],
)
else:
target_next_qout = self.target_model(batch_next_state)
4 changes: 3 additions & 1 deletion pfrl/agents/categorical_dqn.py
@@ -117,7 +117,9 @@ def _compute_target_values(self, exp_batch):
batch_next_state = exp_batch["next_state"]
if self.recurrent:
target_next_qout, _ = pack_and_forward(
self.target_model, batch_next_state, exp_batch["next_recurrent_state"],
self.target_model,
batch_next_state,
exp_batch["next_recurrent_state"],
)
else:
target_next_qout = self.target_model(batch_next_state)
