Skip to content
Snippets Groups Projects
Commit f45eb920 authored by tosatto's avatar tosatto
Browse files

exploit policy

parent d01716dc
Branches
No related tags found
No related merge requests found
......@@ -18,13 +18,60 @@ class BootPolicy(TDPolicy):
def draw_action(self, state):
#self._idx = np.random.randint(self._n_approximators)
if not np.random.uniform() < self._epsilon(state):
if self._eval:
qs = [self._approximator.predict(state, idx=i) for i in range(self._n_approximators)]
max_as, count = np.unique(np.argmax(qs, axis=1), return_counts=True)
max_a = np.array([max_as[np.random.choice(np.argwhere(count == np.max(count)).ravel())]])
return max_a
else:
q = self._approximator.predict(state, idx=self._idx)
max_a = np.argwhere(q == np.max(q)).ravel()
if len(max_a) > 1:
max_a = np.array([np.random.choice(max_a)])
return max_a
return np.array([np.random.choice(self._approximator.n_actions)])
def set_eval(self, value):
self._eval = value
def set_epsilon(self, epsilon):
self._epsilon = epsilon
def set_idx(self, idx):
self._idx = idx
class ExploitPolicy(TDPolicy):
def __init__(self, n_approximators, epsilon=None):
if epsilon is None:
epsilon = Parameter(0.)
super(ExploitPolicy, self).__init__()
self._n_approximators = n_approximators
self._epsilon = epsilon
self._idx = None
self._eval = False
def draw_action(self, state):
#self._idx = np.random.randint(self._n_approximators)
if not np.random.uniform() < self._epsilon(state):
if self._eval:
q = self._approximator.predict(state, idx=self._n_approximators - 1)
max_a = np.argwhere(q == np.max(q)).ravel()
if len(max_a) > 1:
max_a = np.array([np.random.choice(max_a)])
return max_a
else:
q = self._approximator.predict(state, idx=self._idx)
max_a = np.argwhere(q == np.max(q)).ravel()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment