;-*- Package: (cl-user) -*-
;;; The classic random walk, solved by TD methods.
;;; This version is for 1) off-line n-step methods, 2) conventional off-line TD(lambda)
;;; with accumulating traces (which is equivalent at episodes' ends to the off-line
;;; lambda-return algorithm), and 3) the offline lambda-return algorithm
;;; The code and calls for some of the figures in the RL text are given at the very end.
;; Global parameters and mutable state shared by all the TD experiments below.
(defvar NN 1) ; the n of n-step methods
(defvar n 19) ; the number of nonterminal states
(defvar w) ; the vector of weights = predictions
(defvar delta-w) ; weight increments accumulated over one off-line episode
(defvar e) ; the eligibility trace
;; NOTE(review): defvar-ing the standard symbol CL:LAMBDA as a variable is
;; nonportable (it names the standard macro), though implementations accept it.
(defvar lambda .9) ; trace decay parameter
(defvar alpha 0.1) ; learning-rate parameter
(defvar standard-walks nil) ; list of standard walks
(defvar targets) ; the correct predictions
(defvar right-outcome 1.0) ; outcome when the walk falls off the right end
(defvar left-outcome -1.0) ; outcome when the walk falls off the left end
(defvar initial-w 0.0) ; value every weight is reset to by INIT
(defun setup (num-runs num-walks)
  "Allocate the weight, increment, and trace vectors, then generate and
cache the standard walk sets.  Returns the number of walk sets."
  (setq w (make-array n)
        delta-w (make-array n)
        e (make-array n))
  (setq standard-walks (standard-walks num-runs num-walks))
  (length standard-walks))
(defun init ()
  "Reset every weight to INITIAL-W and recompute TARGETS, the correct
predictions, which are evenly spaced between the two outcomes."
  (fill w initial-w)
  (setq targets
        (loop for i from 1 to n
              collect (+ left-outcome
                         (* (- right-outcome left-outcome)
                            (/ i (+ n 1)))))))
(defun init-traces ()
  "Zero the eligibility trace of every nonterminal state."
  (dotimes (i n)
    (setf (aref e i) 0)))
(defun learn-fv (x target)
  "Accumulate into DELTA-W the step-size-scaled error moving state X's
prediction toward TARGET."
  (setf (aref delta-w x)
        (+ (aref delta-w x)
           (* alpha (- target (aref w x))))))
(defun learn-TDlambda (x target)
  "One accumulating-trace TD(lambda) update: decay every trace by LAMBDA,
bump state X's trace by 1, then credit the TD error toward TARGET to all
states' DELTA-W entries in proportion to their traces."
  (dotimes (i n)
    (setf (aref e i) (* lambda (aref e i))))
  (incf (aref e x))
  (let ((error (- target (aref w x))))
    (dotimes (i n)
      (incf (aref delta-w i) (* alpha error (aref e i))))))
(defun process-walk-TDlambda (walk)
  "Apply off-line TD(lambda) to one WALK, a list (outcome states): make a
TD backup from each state to its successor, then from the final state to
the outcome.  Increments accumulate in DELTA-W and are added to W only
after the whole episode."
  (destructuring-bind (outcome states) walk
    (dotimes (i n) (setf (aref delta-w i) 0))
    (init-traces)
    (loop for (s1 s2) on states
          while s2
          do (learn-TDlambda s1 (aref w s2)))
    (learn-TDlambda (first (last states)) outcome)
    (dotimes (i n) (incf (aref w i) (aref delta-w i)))))
(defun process-walk-nstep (walk)
  "Apply the off-line n-step TD method, with n = NN, to one WALK, a list
of the form (outcome states).  Each state's target is the current value
of the state NN steps later, or the outcome when the episode terminates
within NN steps.  Increments accumulate in DELTA-W and are added to W
only after the whole episode."
  (destructuring-bind (outcome states) walk
    (dotimes (i n) (setf (aref delta-w i) 0))
    ;; Count down the states remaining instead of recomputing
    ;; (length rest) on every iteration, which was O(T^2) in walk length.
    (loop for s1 in states
          for rest on states
          for remaining downfrom (length states)
          do (learn-fv s1 (if (>= NN remaining)
                              outcome
                              (aref w (nth NN rest)))))
    (dotimes (i n) (incf (aref w i) (aref delta-w i)))))
(defun process-walk-lambda-return (walk)
"Apply the off-line lambda-return algorithm to one WALK, a list of the
form (outcome states).  Each visited state's target is its lambda-return:
the (1 - lambda)-weighted sum of lambda^(k-1) times the value of the
state k steps later, plus lambda^m times the outcome, where m is the
number of later states.  Increments accumulate in DELTA-W and are added
to W only after the whole episode."
(destructuring-bind (outcome states) walk
(loop for i below n do (setf (aref delta-w i) 0))
(loop for s1 in states
for rest on states
;; (cdr rest) holds the states after s1.  Each bootstrapped value is
;; weighted by lambda^(k-1); the remaining lambda^m probability mass
;; goes to the actual episode outcome.
for target = (+ (* (- 1 lambda)
(loop for St+n in (cdr rest)
for ln = 1 then (* ln lambda)
sum (* ln (aref w St+n))))
(* (expt lambda (length (cdr rest)))
outcome))
do (learn-fv s1 target))
(loop for i below n do (incf (aref w i) (aref delta-w i)))))
(defun standard-walks (num-sets-of-walks num-walks)
  "Return NUM-SETS-OF-WALKS lists, each of NUM-WALKS random walks, all
drawn from a single copy of the standard random state so that every
experiment sees the same reproducible sequence of walks."
  (let ((random-state (ut::copy-of-standard-random-state)))
    (loop repeat num-sets-of-walks
          collect (loop repeat num-walks
                        collect (random-walk n random-state)))))
(defun random-walk (n &optional (random-state *random-state*))
  "Generate one random walk over N nonterminal states: start in the
middle and step left or right with equal probability until falling off
either end.  Returns a list (outcome states) where outcome is
LEFT-OUTCOME or RIGHT-OUTCOME and states is the sequence of visited
nonterminal states (terminal states are not included)."
  (loop with start-state = (truncate (/ n 2))
        for x = start-state then (with-prob .5 (+ x 1) (- x 1) random-state)
        while (AND (>= x 0) (< x n))
        collect x into xs
        ;; Use the outcome parameters rather than hard-coded -1/1 so that
        ;; changing LEFT-OUTCOME/RIGHT-OUTCOME affects generated walks too.
        finally (return (list (if (< x 0) left-outcome right-outcome) xs))))
(defun residual-error ()
  "Returns the residual RMSE between the current and correct predictions"
  (rmse 0 (mapcar #'- (coerce w 'list) targets)))
(defun learning-curve-TDlambda (alpha-arg lambda-arg)
  "Run off-line TD(lambda) with the given step size and trace decay over
every standard walk set; return the mean learning curve of RMS errors --
the initial error first, then the error after each successive walk."
  (setq alpha alpha-arg
        lambda lambda-arg)
  (multi-mean
   (loop for walk-set in standard-walks
         collect (progn
                   (init)
                   (cons (residual-error)
                         (loop for walk in walk-set
                               collect (progn
                                         (process-walk-TDlambda walk)
                                         (residual-error))))))))
(defun learning-curve-nstep (alpha-arg NN-arg)
  "Run the off-line n-step TD method with the given step size and n over
every standard walk set; return the mean learning curve of RMS errors --
the initial error first, then the error after each successive walk."
  (setq alpha alpha-arg
        NN NN-arg)
  (multi-mean
   (loop for walk-set in standard-walks
         collect (progn
                   (init)
                   (cons (residual-error)
                         (loop for walk in walk-set
                               collect (progn
                                         (process-walk-nstep walk)
                                         (residual-error))))))))
(defun learning-curve-lambda-return (alpha-arg lambda-arg)
  "Run the off-line lambda-return algorithm with the given step size and
trace decay over every standard walk set; return the mean learning curve
of RMS errors -- the initial error first, then the error after each walk."
  (setq alpha alpha-arg
        lambda lambda-arg)
  (multi-mean
   (loop for walk-set in standard-walks
         collect (progn
                   (init)
                   (cons (residual-error)
                         (loop for walk in walk-set
                               collect (progn
                                         (process-walk-lambda-return walk)
                                         (residual-error))))))))
(defun offline-TDlambda ()
"summary results as fn of alpha, for various lambda"
;; For each lambda, sweep alpha logarithmically from e^-5 to e^-1 and
;; plot the mean RMS error of the learning run; each sweep stops early
;; once the error reaches 10 (divergence).  The (init) point anchors the
;; curve at alpha = 0.  NOTE(review): the loop variable rebinds the
;; special variable LAMBDA, which learning-curve-TDlambda then setq's --
;; this relies on dynamic binding.
(graph (loop for lambda in '(0 .4 .8 .9 .95 .975 .99 1) collect
(cons (progn (init) (list 0 (residual-error)))
(loop for alog from -5 to -1 by 0.1
for alpha = (exp alog)
for error = (mean (rest (learning-curve-TDlambda alpha lambda)))
while (< error 10)
collect (list alpha error)))))
;; Axis configuration for the graphing package.
(y-tick-marks .35 .4 .45 .5 .55)
(x-tick-marks 0 .1 .2 .3)
(y-graph-limits .33 .55)
(x-graph-limits 0 .3))
(defun offline-nstep ()
"summary results as fn of alpha, for various n"
;; For each n in 1, 2, 4, ... 512, sweep alpha logarithmically from
;; e^-5 to 1 and plot the mean RMS error of the learning run; each sweep
;; stops early once the error reaches 10 (divergence).  The (init) point
;; anchors the curve at alpha = 0.
(graph (loop for nstep = 1 then (* nstep 2) while (< nstep 1000) collect
(cons (progn (init) (list 0 (residual-error)))
(loop for alog from -5 to 0 by 0.1
for alpha = (exp alog)
for error = (mean (rest (learning-curve-nstep alpha nstep)))
while (< error 10)
collect (list alpha error)))))
;; Axis configuration for the graphing package.
(y-tick-marks .25 .3 .35 .4 .45 .5 .55)
(x-tick-marks 0 .1 .2 .3)
(y-graph-limits .25 .55)
(x-graph-limits 0 .3))
(defun offline-lambda-return ()
"summary results as fn of alpha, for various lambda"
;; For each lambda, sweep alpha logarithmically from e^-5 to e^-1 and
;; plot the mean RMS error of the learning run; each sweep stops early
;; once the error reaches 10 (divergence).  The (init) point anchors the
;; curve at alpha = 0.  NOTE(review): the loop variable rebinds the
;; special variable LAMBDA -- relies on dynamic binding.
(graph (loop for lambda in '(0 .4 .8 .9 .95 .975 .99 1) collect
(cons (progn (init) (list 0 (residual-error)))
(loop for alog from -5 to -1 by 0.1
for alpha = (exp alog)
for error = (mean (rest (learning-curve-lambda-return alpha lambda)))
while (< error 10)
collect (list alpha error)))))
;; Overlay the 3-step TD curve for comparison against the lambda curves.
(graph+ (cons (progn (init) (list 0 (residual-error)))
(loop for alog from -5 to 0 by 0.1
for alpha = (exp alog)
for error = (mean (rest (learning-curve-nstep alpha 3)))
while (< error 10)
collect (list alpha error))))
;; Axis configuration for the graphing package.
(y-tick-marks .25 .3 .35 .4 .45 .5 .55)
(x-tick-marks 0 .1 .2 .3)
(y-graph-limits .25 .55)
(x-graph-limits 0 .3))
#|
(setup 100 10)
(offline-nstep) 1st edition, Figure 7.2, lower graph
(offline-TDlambda) 1st edition, Figure 7.6 (how actually made)
(offline-lambda-return) 1st edition, Figure 7.6 (ostensibly)
(y-tick-marks .25 .3 .35 .4 .45 .5 .55)
(x-tick-marks 0 .2 .4 .6 .8 1)
(y-graph-limits .25 .55)
(x-graph-limits 0 1)
|#