This repository was archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathbaselines.patch
More file actions
71 lines (65 loc) · 2.51 KB
/
baselines.patch
File metadata and controls
71 lines (65 loc) · 2.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
diff --git a/baselines/bench/monitor.py b/baselines/bench/monitor.py
index 0db473a..f2d93ac 100644
--- a/baselines/bench/monitor.py
+++ b/baselines/bench/monitor.py
@@ -76,6 +76,7 @@ class Monitor(Wrapper):
self.total_steps += 1
def close(self):
+ super().close()
if self.f is not None:
self.f.close()
diff --git a/baselines/common/vec_env/dummy_vec_env.py b/baselines/common/vec_env/dummy_vec_env.py
index 60db11d..387ea11 100644
--- a/baselines/common/vec_env/dummy_vec_env.py
+++ b/baselines/common/vec_env/dummy_vec_env.py
@@ -61,6 +61,13 @@ class DummyVecEnv(VecEnv):
self._save_obs(e, obs)
return self._obs_from_buf()
+ def close_extras(self):
+ """
+ Clean up the extra resources, beyond what's in this base class.
+ Only runs when not self.closed().
+ """
+ self.envs[0].close()
+
def _save_obs(self, e, obs):
for k in self.keys:
if k is None:
diff --git a/baselines/common/vec_env/subproc_vec_env.py b/baselines/common/vec_env/subproc_vec_env.py
index 4dc4d2c..a1ec19c 100644
--- a/baselines/common/vec_env/subproc_vec_env.py
+++ b/baselines/common/vec_env/subproc_vec_env.py
@@ -70,13 +70,29 @@ class SubprocVecEnv(VecEnv):
results = [remote.recv() for remote in self.remotes]
self.waiting = False
obs, rews, dones, infos = zip(*results)
- return np.stack(obs), np.stack(rews), np.stack(dones), infos
+ if isinstance(obs[0], dict):
+ obs_output = {
+ key: np.stack([obs_[key] for obs_ in obs])
+ for key in obs[0].keys()
+ }
+ else:
+ obs_output = np.stack(obs)
+ return obs_output, np.stack(rews), np.stack(dones), infos
def reset(self):
self._assert_not_closed()
for remote in self.remotes:
remote.send(('reset', None))
- return np.stack([remote.recv() for remote in self.remotes])
+
+ obs = [remote.recv() for remote in self.remotes]
+ if isinstance(obs[0], dict):
+ obs_output = {
+ key: np.stack([obs_[key] for obs_ in obs])
+ for key in obs[0].keys()
+ }
+ else:
+ obs_output = np.stack(obs)
+ return obs_output
def close_extras(self):
self.closed = True