;-*- Package: (discrete-walk) -*-
;;; A simulation of a TD(lambda) learning system to predict the expected outcome
;;; of a discrete-state random walk like that in the original 1988 TD paper.
(defpackage :discrete-walk
(:use :common-lisp :g :ut :graph)
(:nicknames :dwalk))
(in-package :dwalk)
;;; Global parameters and state shared by all functions below.  They are read
;;; and written as special variables rather than passed as arguments.
(defvar n 5) ; the number of nonterminal states
(defvar w) ; the vector of weights = predictions, one per nonterminal state (made by SETUP)
(defvar e) ; the eligibility trace, one per nonterminal state (made by SETUP)
;; NOTE(review): LAMBDA is an external symbol of COMMON-LISP; proclaiming it
;; special is undefined behavior per CLHS 11.1.2.1.2, and implementations with
;; package locks (e.g. SBCL) reject it.  Renaming requires updating every use.
(defvar lambda .9) ; trace decay parameter
(defvar alpha 0.1) ; learning-rate parameter
(defvar initial-w 0.5) ; value INIT assigns to every weight
(defvar standard-walks nil) ; list of walk sets, each a list of walks (built by SETUP)
(defvar trace-type :none) ; :replace, :accumulate, :average, :1/t or :none
(defvar alpha-type :fixed) ; :fixed, :1/t, or :1/t-max
(defvar alpha-array) ; used when each state has a different alpha; INIT fills it with ALPHA
(defvar u) ; usage count = number of times updated (per state; reset by INIT)
(defun setup (num-runs num-walks)
  "Allocate the per-state arrays W, E, U and ALPHA-ARRAY (each of length N)
and regenerate STANDARD-WALKS as NUM-RUNS sets of NUM-WALKS walks.
Array contents are filled later by INIT and INIT-TRACES.
Returns the number of walk sets generated."
  (setf w (make-array n)
        e (make-array n)
        u (make-array n)
        alpha-array (make-array n)
        standard-walks (standard-walks num-runs num-walks))
  (length standard-walks))
(defun init ()
  "Reset the learner for a new run: every weight becomes INITIAL-W, every
per-state step size becomes ALPHA, and every usage count becomes 0."
  (dotimes (state n)
    (setf (aref w state) initial-w
          (aref alpha-array state) alpha
          (aref u state) 0)))
(defun init-traces ()
  "Zero the eligibility trace E for each of the N states."
  (dotimes (state n)
    (setf (aref e state) 0)))
(defun learn (x target)
  "Perform one learning step for state X toward TARGET.
X indexes W, E, U and ALPHA-ARRAY.  The step runs in three stages, in order:
  1. update X's step size in ALPHA-ARRAY according to ALPHA-TYPE;
  2. update the eligibility traces E according to TRACE-TYPE;
  3. move the weights.  With TRACE-TYPE :none only W[X] moves, by
     ALPHA-ARRAY[X] * (TARGET - W[X]).  Otherwise every W[I] moves by
     ALPHA-ARRAY[I] * (TARGET - W[X]) * E[I]."
  ;; Stage 1: step size for state X.
  (ecase alpha-type
    ;; Count this update and use step size 1/count.
    (:1/t (incf (aref u x))
          (setf (aref alpha-array x) (/ 1.0 (aref u x))))
    ;; Keep the value INIT copied from ALPHA.
    (:fixed)
    ;; Like :1/t, but stop counting once the count exceeds 1/ALPHA, so the
    ;; step size stops shrinking near ALPHA.
    (:1/t-max (when (<= (aref u x) (/ 1 alpha))
                (incf (aref u x))
                (setf (aref alpha-array x) (/ 1.0 (aref u x))))))
  ;; Stage 2: eligibility traces.
  (ecase trace-type
    (:none)
    ;; Replacing traces: decay all by LAMBDA, then set X's trace to 1.
    ;; NOTE(review): this also subtracts X's decayed trace from its usage
    ;; count U.  ALPHA-ARRAY[X] was already computed in stage 1, so this only
    ;; affects later step sizes.  The intent is not evident from this code;
    ;; confirm.
    (:replace (loop for i below n do (setf (aref e i) (* lambda (aref e i))))
              (decf (aref u x) (aref e x))
              (setf (aref e x) 1))
    ;; Accumulating traces: decay all by LAMBDA, then add 1 to X's trace.
    (:accumulate (loop for i below n do (setf (aref e i) (* lambda (aref e i))))
                 (incf (aref e x) 1))
    ;; Decay all by LAMBDA, then E[X] <- 1 + E[X] * (1 - ALPHA-ARRAY[X]).
    (:average (loop for i below n do (setf (aref e i) (* lambda (aref e i))))
              (setf (aref e x) (+ 1 (* (aref e x) (- 1 (aref alpha-array x))))))
    ;; Count the visit, add 1 to X's trace, then scale every trace by
    ;; 1/U[X].  Unlike the branches above, X's trace is bumped before the
    ;; scaling.  The FOR clause rebinds the special LAMBDA to that
    ;; loop-invariant factor for the loop's duration.
    ;; NOTE(review): U[X] is also incremented in stage 1 when ALPHA-TYPE is
    ;; :1/t, so both stages count this visit; confirm that is intended.
    (:1/t (incf (aref u x))
          (incf (aref e x) 1)
          (loop for i below n
                for lambda = (float (/ (aref u x)))
                do (setf (aref e i) (* lambda (aref e i))))))
  ;; Stage 3: weight update.
  (if (eq trace-type :none)
      (incf (aref w x) (* (aref alpha-array x) (- target (aref w x))))
      ;; The error is computed once, from W[X] before any weight changes,
      ;; and shared by all states in proportion to their traces.
      (loop for i below n
            with error = (- target (aref w x))
            do (incf (aref w i) (* (aref alpha-array i) error (aref e i))))))
(defun process-walk (walk)
  "Learn online from WALK, a list (OUTCOME STATES), going forward in time.
Each visited state is moved toward the current prediction of its successor,
and the final state toward OUTCOME.  Traces are cleared first unless
TRACE-TYPE is :none.  Returns the value of the final LEARN call."
  (destructuring-bind (outcome states) walk
    (unless (eq trace-type :none)
      (init-traces))
    ;; Walk the list by conses so each state can see its successor.
    (loop for (current . later) on states
          while later
          do (learn current (aref w (first later))))
    (learn (car (last states)) outcome)))
(defun process-walk-backwards (walk)
  "Learn from WALK, a list (OUTCOME STATES), going backward in time.
The final state is first moved toward OUTCOME.  Then each earlier state,
from last to first, is moved toward the (already updated) prediction of
its successor.  Traces are cleared first unless TRACE-TYPE is :none."
  (destructuring-bind (outcome states) walk
    (unless (eq trace-type :none)
      (init-traces))
    (learn (car (last states)) outcome)
    ;; In the reversed list, each element is the successor of the one after it.
    (loop for (successor . earlier) on (reverse states)
          while earlier
          do (learn (first earlier) (aref w successor)))))
(defun process-walk-MC (walk)
  "Monte Carlo learning from WALK, a list (OUTCOME STATES).
Every visited state, from last to first, is moved directly toward the final
OUTCOME.  As in PROCESS-WALK and PROCESS-WALK-BACKWARDS, eligibility traces
are cleared at the start of the walk unless TRACE-TYPE is :none.  This keeps
traces from an earlier walk, or never-initialised trace entries, out of this
walk's updates."
  (destructuring-bind (outcome states) walk
    (unless (eq trace-type :none)
      (init-traces))
    (loop for s in (reverse states)
          do (learn s outcome))))
(defun standard-walks (num-sets-of-walks num-walks)
  "Return a list of NUM-SETS-OF-WALKS walk sets, each a list of NUM-WALKS
walks over N states produced by RANDOM-WALK.  All walks draw from a single
state returned by UT::COPY-OF-STANDARD-RANDOM-STATE.  Presumably that is a
copy of a fixed state, so every call yields the same walks; confirm in UT."
  ;; WITH must come before REPEAT: in ANSI LOOP, variable clauses (WITH,
  ;; FOR) must precede main clauses, and REPEAT is a main clause.
  (loop with random-state = (ut::copy-of-standard-random-state)
        repeat num-sets-of-walks
        collect (loop repeat num-walks
                      collect (random-walk n random-state))))
(defun random-walk (n &optional (random-state *random-state*))
  "Generate one random-walk episode over nonterminal states 0..N-1.
The walk starts in the middle state (state N/2 rounded down).  Each step
moves to (+ x 1) or (- x 1) via WITH-PROB with probability .5 and
RANDOM-STATE; presumably the first with probability .5, confirm in WITH-PROB.
The walk continues until it leaves 0..N-1.
Returns (OUTCOME STATES): OUTCOME is 0 if the walk exited below state 0 and
1 if it exited above state N-1; STATES lists the visited states in order."
  ;; FLOOR rather than ROUND: ROUND rounds halves to even, so (round 7/2) => 4
  ;; and (round 19/2) => 10, starting odd-sized walks off centre.  For odd N,
  ;; (floor n 2) is the exact centre; for even N it equals N/2 as before.
  (loop with start-state = (floor n 2)
        for x = start-state then (with-prob .5 (+ x 1) (- x 1) random-state)
        while (and (>= x 0) (< x n))
        collect x into xs
        finally (return (list (if (< x 0) 0 1) xs))))
(defun residual-error ()
  "Return the residual error between the current predictions W and the
correct predictions (i+1)/(N+1) for states i = 0..N-1.  The list of
differences W[i] - (i+1)/(N+1) is passed to RMSE with 0 as its first
argument (presumably the target value; confirm in RMSE).
States whose weight is below -.1 are left out.  This skips states still at
a strongly negative initial weight, such as the -1 used by
BATCH-LEARNING-CURVE-TD."
  (let ((errors '()))
    (dotimes (state n)
      (let ((prediction (aref w state)))
        (when (>= prediction -.1)
          (push (- prediction (/ (+ state 1) (+ n 1))) errors))))
    (rmse 0 (nreverse errors))))
(defun explore (alpha-type-arg alpha-arg lambda-arg trace-type-arg forward?
                &optional (number-type 'float))
  "Set the global learning parameters.  LAMBDA and ALPHA are coerced to
NUMBER-TYPE.  Then, for every walk set in STANDARD-WALKS, reinitialise with
INIT and present the set's walks 100 times: forward with PROCESS-WALK when
FORWARD? is true, otherwise with PROCESS-WALK-BACKWARDS.  The final residual
errors, one per walk set, go to STATS, and the result of RECORD on that
is returned."
  (setf alpha-type alpha-type-arg
        lambda (coerce lambda-arg number-type)
        alpha (coerce alpha-arg number-type)
        trace-type trace-type-arg)
  (let ((present (if forward? #'process-walk #'process-walk-backwards)))
    (record
     (stats
      (loop for walk-set in standard-walks
            collect (progn
                      (init)
                      (loop repeat 100
                            do (dolist (walk walk-set)
                                 (funcall present walk)))
                      (residual-error)))))))
(defun learning-curve (alpha-type-arg alpha-arg lambda-arg trace-type-arg
                       &optional (processing :forward) (initial-w-arg 0.5)
                         (number-type 'float))
  "Set the global learning parameters, including INITIAL-W.  LAMBDA and
ALPHA are coerced to NUMBER-TYPE.  Then compute, for every walk set in
STANDARD-WALKS, the list of residual errors: one before any learning, then
one after each walk is presented.  PROCESSING selects how a walk is
presented: :forward (PROCESS-WALK), :backward (PROCESS-WALK-BACKWARDS) or
:MC (PROCESS-WALK-MC); any other value signals an error when a walk is
presented.  Returns MULTI-MEAN of the per-set error lists."
  (setf alpha-type alpha-type-arg
        lambda (coerce lambda-arg number-type)
        alpha (coerce alpha-arg number-type)
        trace-type trace-type-arg
        initial-w initial-w-arg)
  (flet ((present (walk)
           (ecase processing
             (:forward (process-walk walk))
             (:backward (process-walk-backwards walk))
             (:MC (process-walk-MC walk)))))
    (multi-mean
     (loop for walk-set in standard-walks
           collect (progn
                     (init)
                     (cons (residual-error)
                           (loop for walk in walk-set
                                 do (present walk)
                                 collect (residual-error))))))))
(defun batch-learning-curve-TD ()
  "Batch TD(0) learning curve with step size 0.01 and no traces.
For each walk set in STANDARD-WALKS, INIT runs once.  Then, for K = 1, 2, ...
up to the set's length, the first K walks are presented repeatedly with
PROCESS-WALK until a full pass changes the weights by less than 1e-7 in
total (sum of absolute changes).  RESIDUAL-ERROR is recorded after each K.
Weights carry over from one K to the next within a set.  Weights start at
-1, so RESIDUAL-ERROR (which skips weights below -.1) leaves out states
still at their initial value.  Returns MULTI-MEAN of the per-set error lists."
  ;; Pin the step-size schedule too.  A 1/t schedule inherited from an
  ;; earlier experiment would shrink the step size on every pass, and
  ;; would satisfy the convergence test without reaching the batch fixed point.
  (setq alpha-type :fixed)
  (setq alpha 0.01)
  (setq lambda 0.0)
  (setq trace-type :none)
  (setq initial-w -1)
  (multi-mean
   (loop with last-w = (make-array n)
         for walk-set in standard-walks
         do (init)
         collect (loop for num-walks from 1 to (length walk-set)
                       for walk-subset = (firstn num-walks walk-set)
                       ;; Sweep the subset until the weights stop changing.
                       do (loop do (loop for i below n do (setf (aref last-w i) (aref w i)))
                                do (loop for walk in walk-subset
                                         do (process-walk walk))
                                until (> .0000001 (loop for i below n
                                                        sum (abs (- (aref w i) (aref last-w i))))))
                       collect (residual-error)))))