# NOTE: the star import stays first so that the explicit imports below take
# precedence over any same-named symbol it might export.
from scipy import *

import time

import numpy as np
import pylab
from matplotlib.patches import Rectangle

from maze import QLearning

#############################################################
# QLambda
# implements the Q-Learning algorithm with eligibility 
# traces on a maze task.
# by: Thomas Rueckstiess, WS 09/10 - Machine Learning I, TUM
#############################################################
class QLambda(QLearning):
    """Q-learning with eligibility traces on the maze task.

    Keeps one eligibility value per (row, column, action) entry and uses it
    to spread every temporal-difference error over all recently visited
    state-action pairs instead of only the last one.

    The attributes ``maze``, ``state``, ``start``, ``goal``, ``qtable``,
    ``gamma``, ``epsilon``, ``epsilondecay``, ``history``, ``action`` and
    ``reward`` as well as the methods ``agent``, ``environment`` and
    ``updateplot`` are expected to come from ``QLearning``.
    """

    # eligibility traces for all state-action pairs; this class-level array
    # is only a default, every instance gets its own copy in __init__
    etrace = np.zeros((8, 8, 4))
    lambda_ = 0

    # image handle used for the visualization of the traces
    traces = None

    def __init__(self, lambda_):
        """Store the trace decay factor and build the three-panel figure.

        lambda_ -- decay factor of the eligibility traces; on every
                   iteration the traces are scaled by gamma * lambda_.

        NOTE(review): QLearning.__init__ is intentionally not called (the
        original did not call it either); maze, state and qtable are read
        below, so they are presumably class attributes of QLearning.
        """
        # save lambda on the instance (careful: "lambda" is a python keyword)
        self.lambda_ = lambda_

        # Per-instance trace table. Without this, the in-place decay in
        # iteration() would modify the class-level array shared by all
        # instances.
        self.etrace = np.zeros(self.qtable.shape)

        # figure size
        self.figure = pylab.figure(figsize=(12, 4))

        # "hold" makes successive plotting commands draw into the same axes.
        # pylab.hold was removed in Matplotlib 3.0 (holding is always on
        # there), so only call it where it still exists.
        if hasattr(pylab, 'hold'):
            pylab.hold(True)

        # reduce learning rate (more stable with traces)
        self.alpha = 0.1

        # create subplots, one for the maze/agent...
        pylab.subplot(131)
        pylab.title('agent in maze')
        self.manager = pylab.get_current_fig_manager()
        pylab.imshow(-self.maze, cmap=pylab.cm.gray, interpolation='nearest')
        # red square marking the agent; x is the column, y is the row
        self.rectangle = Rectangle(
            xy=(self.state[1] - 0.25, self.state[0] - 0.25),
            width=0.5,
            height=0.5,
            facecolor='red',
        )
        pylab.gca().add_artist(self.rectangle)

        # ...one for the value function (best Q-value per cell)...
        pylab.subplot(132)
        pylab.title('value function')
        self.image = pylab.imshow(
            self.qtable.max(axis=2), cmap=pylab.cm.hot, interpolation='nearest'
        )

        # ...and another one for the traces (largest trace per cell)
        pylab.subplot(133)
        pylab.title('eligibility traces')
        self.traces = pylab.imshow(
            self.etrace.max(axis=2), cmap=pylab.cm.hot, interpolation='nearest'
        )

    def updateplot(self):
        """ update the graphical elements for visualization. """
        QLearning.updateplot(self)

        # update traces visualization and rescale its color range
        self.traces.set_array(self.etrace.max(axis=2))
        self.traces.autoscale()

    def iteration(self):
        """ this is the main loop of the algorithm. it calls agent, environment
            updates the Q-table and calls the plot method.

            Always returns True. """

        # execute agent and environment parts of iteration
        self.agent(self.state)
        while not self.environment(self.action):
            # loop until agent delivers a valid action (that does not hit the wall)
            self.agent(self.state)

        # update eligibility trace values (decay + visited state)
        # NOTE(review): the trace is incremented at the state held *after*
        # environment() ran, matching what is stored in self.history below;
        # confirm against QLearning that this pairing is intended.
        self.etrace *= self.gamma * self.lambda_
        self.etrace[self.state[0], self.state[1], self.action] += 1

        # update value function
        if self.history:
            laststate, lastaction, _ = self.history
            # TD error of the previous state-action pair
            best_next = self.qtable[self.state[0], self.state[1], :].max()
            delta = (self.reward
                     + self.gamma * best_next
                     - self.qtable[laststate[0], laststate[1], lastaction])
            # Apply the error to every entry in proportion to its trace;
            # vectorised, so it follows the table's shape instead of
            # hard-coded 8x8x4 loop bounds.
            self.qtable += self.alpha * delta * self.etrace

        # write s,a,r to history
        self.history = (self.state, self.action, self.reward)

        # update the visualization
        self.updateplot()

        # reset agent if goal reached
        if np.array_equal(self.goal, self.state):
            self.state = self.start
            self.history = None
            self.etrace = np.zeros_like(self.etrace)

        # reduce epsilon
        self.epsilon *= self.epsilondecay
        return True


### main function ###################################################   
if __name__ == '__main__':
    # Run the learner with lambda = 0.7 until the user interrupts it.
    qlambda = QLambda(0.7)
    try:
        while True:
            qlambda.iteration()
    except KeyboardInterrupt:
        # Ctrl-C leaves the otherwise endless loop so that the call below is
        # actually reached (it was unreachable after the bare infinite loop).
        pass
    # keep the final figure on screen until the window is closed
    pylab.show()
