;-*- Package: (cl-user) -*-
;;; The classic random walk, solved by TD methods.
;;; This version is for 1) off-line n-step methods, 2) conventional off-line TD(lambda)
;;; with accumulating traces (which is equivalent at episodes' ends to the off-line
;;; lambda-return algorithm), and 3) the offline lambda-return algorithm
;;; The code and calls for some of the figures in the RL text are given at the very end.
;; Global parameters and mutable state shared by all the TD experiments below.
(defvar NN 1) ; the n of n-step methods
(defvar n 19) ; the number of nonterminal states
(defvar w) ; the vector of weights = predictions
(defvar delta-w) ; weight increments accumulated over one off-line episode
(defvar e) ; the eligibility trace
;; NOTE(review): defvar-ing the standard symbol CL:LAMBDA as a variable is
;; nonportable (it names the standard macro), though implementations accept it.
(defvar lambda .9) ; trace decay parameter
(defvar alpha 0.1) ; learning-rate parameter
(defvar standard-walks nil) ; list of standard walks
(defvar targets) ; the correct predictions
(defvar right-outcome 1.0) ; outcome when the walk falls off the right end
(defvar left-outcome -1.0) ; outcome when the walk falls off the left end
(defvar initial-w 0.0) ; value every weight is reset to by INIT
(defun setup (num-runs num-walks)
  "Allocate the weight, increment, and trace vectors, then generate and
cache the standard walk sets.  Returns the number of walk sets."
  (setq w (make-array n)
        delta-w (make-array n)
        e (make-array n))
  (setq standard-walks (standard-walks num-runs num-walks))
  (length standard-walks))
(defun init ()
  "Reset every weight to INITIAL-W and recompute TARGETS, the correct
predictions, which are evenly spaced between the two outcomes."
  (fill w initial-w)
  (setq targets
        (loop for i from 1 to n
              collect (+ left-outcome
                         (* (- right-outcome left-outcome)
                            (/ i (+ n 1)))))))
(defun init-traces ()
  "Zero the eligibility trace of every nonterminal state."
  (dotimes (i n)
    (setf (aref e i) 0)))
(defun learn-fv (x target)
  "Accumulate into DELTA-W the step-size-scaled error moving state X's
prediction toward TARGET."
  (setf (aref delta-w x)
        (+ (aref delta-w x)
           (* alpha (- target (aref w x))))))
(defun learn-TDlambda (x target)
  "One accumulating-trace TD(lambda) update: decay every trace by LAMBDA,
bump state X's trace by 1, then credit the TD error toward TARGET to all
states' DELTA-W entries in proportion to their traces."
  (dotimes (i n)
    (setf (aref e i) (* lambda (aref e i))))
  (incf (aref e x))
  (let ((error (- target (aref w x))))
    (dotimes (i n)
      (incf (aref delta-w i) (* alpha error (aref e i))))))
(defun process-walk-TDlambda (walk)
  "Apply off-line TD(lambda) to one WALK, a list (outcome states): make a
TD backup from each state to its successor, then from the final state to
the outcome.  Increments accumulate in DELTA-W and are added to W only
after the whole episode."
  (destructuring-bind (outcome states) walk
    (dotimes (i n) (setf (aref delta-w i) 0))
    (init-traces)
    (loop for (s1 s2) on states
          while s2
          do (learn-TDlambda s1 (aref w s2)))
    (learn-TDlambda (first (last states)) outcome)
    (dotimes (i n) (incf (aref w i) (aref delta-w i)))))
(defun process-walk-nstep (walk)
  "Apply the off-line n-step TD method, with n = NN, to one WALK, a list
of the form (outcome states).  Each state's target is the current value
of the state NN steps later, or the outcome when the episode terminates
within NN steps.  Increments accumulate in DELTA-W and are added to W
only after the whole episode."
  (destructuring-bind (outcome states) walk
    (dotimes (i n) (setf (aref delta-w i) 0))
    ;; Count down the states remaining instead of recomputing
    ;; (length rest) on every iteration, which was O(T^2) in walk length.
    (loop for s1 in states
          for rest on states
          for remaining downfrom (length states)
          do (learn-fv s1 (if (>= NN remaining)
                              outcome
                              (aref w (nth NN rest)))))
    (dotimes (i n) (incf (aref w i) (aref delta-w i)))))
(defun process-walk-lambda-return (walk)
"Apply the off-line lambda-return algorithm to one WALK, a list of the
form (outcome states).  Each visited state's target is its lambda-return:
the (1 - lambda)-weighted sum of lambda^(k-1) times the value of the
state k steps later, plus lambda^m times the outcome, where m is the
number of later states.  Increments accumulate in DELTA-W and are added
to W only after the whole episode."
(destructuring-bind (outcome states) walk
(loop for i below n do (setf (aref delta-w i) 0))
(loop for s1 in states
for rest on states
;; (cdr rest) holds the states after s1.  Each bootstrapped value is
;; weighted by lambda^(k-1); the remaining lambda^m probability mass
;; goes to the actual episode outcome.
for target = (+ (* (- 1 lambda)
(loop for St+n in (cdr rest)
for ln = 1 then (* ln lambda)
sum (* ln (aref w St+n))))
(* (expt lambda (length (cdr rest)))
outcome))
do (learn-fv s1 target))
(loop for i below n do (incf (aref w i) (aref delta-w i)))))
(defun standard-walks (num-sets-of-walks num-walks)
  "Return NUM-SETS-OF-WALKS lists, each of NUM-WALKS random walks, all
drawn from a single copy of the standard random state so that every
experiment sees the same reproducible sequence of walks."
  (let ((random-state (ut::copy-of-standard-random-state)))
    (loop repeat num-sets-of-walks
          collect (loop repeat num-walks
                        collect (random-walk n random-state)))))
(defun random-walk (n &optional (random-state *random-state*))
  "Generate one random walk over N nonterminal states: start in the
middle and step left or right with equal probability until falling off
either end.  Returns a list (outcome states) where outcome is
LEFT-OUTCOME or RIGHT-OUTCOME and states is the sequence of visited
nonterminal states (terminal states are not included)."
  (loop with start-state = (truncate (/ n 2))
        for x = start-state then (with-prob .5 (+ x 1) (- x 1) random-state)
        while (AND (>= x 0) (< x n))
        collect x into xs
        ;; Use the outcome parameters rather than hard-coded -1/1 so that
        ;; changing LEFT-OUTCOME/RIGHT-OUTCOME affects generated walks too.
        finally (return (list (if (< x 0) left-outcome right-outcome) xs))))
(defun residual-error ()
  "Returns the residual RMSE between the current and correct predictions"
  (rmse 0 (mapcar #'- (coerce w 'list) targets)))
(defun learning-curve-TDlambda (alpha-arg lambda-arg)
  "Run off-line TD(lambda) with the given step size and trace decay over
every standard walk set; return the mean learning curve of RMS errors --
the initial error first, then the error after each successive walk."
  (setq alpha alpha-arg
        lambda lambda-arg)
  (multi-mean
   (loop for walk-set in standard-walks
         collect (progn
                   (init)
                   (cons (residual-error)
                         (loop for walk in walk-set
                               collect (progn
                                         (process-walk-TDlambda walk)
                                         (residual-error))))))))
(defun learning-curve-nstep (alpha-arg NN-arg)
  "Run the off-line n-step TD method with the given step size and n over
every standard walk set; return the mean learning curve of RMS errors --
the initial error first, then the error after each successive walk."
  (setq alpha alpha-arg
        NN NN-arg)
  (multi-mean
   (loop for walk-set in standard-walks
         collect (progn
                   (init)
                   (cons (residual-error)
                         (loop for walk in walk-set
                               collect (progn
                                         (process-walk-nstep walk)
                                         (residual-error))))))))
(defun learning-curve-lambda-return (alpha-arg lambda-arg)
  "Run the off-line lambda-return algorithm with the given step size and
trace decay over every standard walk set; return the mean learning curve
of RMS errors -- the initial error first, then the error after each walk."
  (setq alpha alpha-arg
        lambda lambda-arg)
  (multi-mean
   (loop for walk-set in standard-walks
         collect (progn
                   (init)
                   (cons (residual-error)
                         (loop for walk in walk-set
                               collect (progn
                                         (process-walk-lambda-return walk)
                                         (residual-error))))))))
(defun offline-TDlambda ()
"summary results as fn of alpha, for various lambda"
;; For each lambda, sweep alpha logarithmically from e^-5 to e^-1 and
;; plot the mean RMS error of the learning run; each sweep stops early
;; once the error reaches 10 (divergence).  The (init) point anchors the
;; curve at alpha = 0.  NOTE(review): the loop variable rebinds the
;; special variable LAMBDA, which learning-curve-TDlambda then setq's --
;; this relies on dynamic binding.
(graph (loop for lambda in '(0 .4 .8 .9 .95 .975 .99 1) collect
(cons (progn (init) (list 0 (residual-error)))
(loop for alog from -5 to -1 by 0.1
for alpha = (exp alog)
for error = (mean (rest (learning-curve-TDlambda alpha lambda)))
while (< error 10)
collect (list alpha error)))))
;; Axis configuration for the graphing package.
(y-tick-marks .35 .4 .45 .5 .55)
(x-tick-marks 0 .1 .2 .3)
(y-graph-limits .33 .55)
(x-graph-limits 0 .3))
(defun offline-nstep ()
"summary results as fn of alpha, for various n"
;; For each n in 1, 2, 4, ... 512, sweep alpha logarithmically from
;; e^-5 to 1 and plot the mean RMS error of the learning run; each sweep
;; stops early once the error reaches 10 (divergence).  The (init) point
;; anchors the curve at alpha = 0.
(graph (loop for nstep = 1 then (* nstep 2) while (< nstep 1000) collect
(cons (progn (init) (list 0 (residual-error)))
(loop for alog from -5 to 0 by 0.1
for alpha = (exp alog)
for error = (mean (rest (learning-curve-nstep alpha nstep)))
while (< error 10)
collect (list alpha error)))))
;; Axis configuration for the graphing package.
(y-tick-marks .25 .3 .35 .4 .45 .5 .55)
(x-tick-marks 0 .1 .2 .3)
(y-graph-limits .25 .55)
(x-graph-limits 0 .3))
(defun offline-lambda-return ()
"summary results as fn of alpha, for various lambda"
;; For each lambda, sweep alpha logarithmically from e^-5 to e^-1 and
;; plot the mean RMS error of the learning run; each sweep stops early
;; once the error reaches 10 (divergence).  The (init) point anchors the
;; curve at alpha = 0.  NOTE(review): the loop variable rebinds the
;; special variable LAMBDA -- relies on dynamic binding.
(graph (loop for lambda in '(0 .4 .8 .9 .95 .975 .99 1) collect
(cons (progn (init) (list 0 (residual-error)))
(loop for alog from -5 to -1 by 0.1
for alpha = (exp alog)
for error = (mean (rest (learning-curve-lambda-return alpha lambda)))
while (< error 10)
collect (list alpha error)))))
;; Overlay the 3-step TD curve for comparison against the lambda curves.
(graph+ (cons (progn (init) (list 0 (residual-error)))
(loop for alog from -5 to 0 by 0.1
for alpha = (exp alog)
for error = (mean (rest (learning-curve-nstep alpha 3)))
while (< error 10)
collect (list alpha error))))
;; Axis configuration for the graphing package.
(y-tick-marks .25 .3 .35 .4 .45 .5 .55)
(x-tick-marks 0 .1 .2 .3)
(y-graph-limits .25 .55)
(x-graph-limits 0 .3))
#|
(setup 100 10)
(offline-nstep) 1st edition, Figure 7.2, lower graph
(offline-TDlambda) 1st edition, Figure 7.6 (how actually made)
(offline-lambda-return) 1st edition, Figure 7.6 (ostensibly)
(y-tick-marks .25 .3 .35 .4 .45 .5 .55)
(x-tick-marks 0 .2 .4 .6 .8 1)
(y-graph-limits .25 .55)
(x-graph-limits 0 1)
|#