(use-package :RLI)

(defun complete-example ()
  (let ((agent (make-instance 'my-Q-agent :alpha .01 :gamma .99))
        (env (make-instance 'maintenance-task))
        (sim (make-instance 'my-simulation)))
    (sim-init sim agent env)
    (sim-steps sim 10000)))

(defclass MY-SIMULATION (simulation)
  ((sum-rewards :initform 0.0)
   (num-rewards :initform 0)
   (sum-interval :initform 1000 :initarg :sum-interval)))

(defmethod sim-collect-data ((sim my-simulation) s a s-prime r)
  (declare (ignore s a s-prime))
  (with-slots (sum-rewards num-rewards sum-interval) sim
    (incf sum-rewards r)
    (incf num-rewards)
    (when (= num-rewards sum-interval)
      (format t "~%Average reward over ~A steps: ~A"
              sum-interval (/ sum-rewards num-rewards))
      (setq sum-rewards 0.0)
      (setq num-rewards 0))))
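
;; Assuming the RLI simulation loop calls sim-collect-data once per step,
;; running (complete-example) prints ten average-reward lines, one per
;; 1000 steps.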


(defclass my-Q-AGENT 
  (Q-table tabular-1step-Q-learning egreedy-policy agent)
  (last-sensation last-action))
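
;; my-Q-agent simply mixes together the pieces defined below: a Q-table of
;; value estimates, the one-step Q-learning update, and an epsilon-greedy
;; policy.  last-sensation and last-action hold the pending transition so
;; that agent-step can back up Q(s,a) once s' and r arrive.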

(defmethod agent-start-episode ((agent my-Q-AGENT) sensation)
  (with-slots (last-sensation last-action) agent
    (setq last-sensation sensation)
    (setq last-action (policy agent sensation))))


(defclass Q-TABLE ()
  ((num-states  :accessor num-states :initarg :num-states)
   (num-actions :accessor num-actions :initarg :num-actions)
   (initial-value :accessor initial-value :initarg :initial-value :initform 0)
   (Q)))

(defmethod agent-init :before ((agent Q-table))
  (with-slots (Q num-states num-actions initial-value) agent
    (setq num-states (num-states (agent-env agent)))
    (setq num-actions (num-actions (agent-env agent)))
    (setf Q (make-array (list num-states num-actions)
                        :initial-element initial-value))))

(defmethod action-values ((agent Q-table) s)
  (with-slots (Q num-actions) agent
    (loop for a below num-actions
          collect (aref Q s a))))

(defmethod state-value ((agent Q-table) s)
  (if (eq s :terminal-state)
    0
    (apply #'max (action-values agent s))))
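
;; For example, on a freshly initialized table every entry still holds
;; initial-value, so with the default of 0, (action-values agent s) is a
;; list of zeros and (state-value agent s) is 0 for any s, terminal or not.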


(defclass TABULAR-1STEP-Q-LEARNING ()
  ((alpha :initarg :alpha :initform 0.1)
   (gamma :accessor gamma :initarg :gamma :initform .9)))

(defmethod agent-step ((agent tabular-1step-Q-learning) s-prime r)
  (with-slots (Q alpha gamma last-sensation last-action) agent
    (let ((s last-sensation)
          (a last-action))
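      ;; One-step Q-learning backup:
      ;;   Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))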
      (incf (aref Q s a)
            (* alpha 
               (+ r 
                  (* gamma (state-value agent s-prime))
                  (- (aref Q s a)))))
      (setq last-sensation s-prime)
      (setq last-action (policy agent s-prime)))))


(defclass EGREEDY-POLICY ()
  ((epsilon :accessor agent-epsilon :initarg :epsilon :initform 0)))

(defmethod policy ((agent egreedy-policy) state)
  (with-slots (epsilon num-actions) agent
    (with-prob epsilon
      (random num-actions)
      (arg-max-random-tiebreak (action-values agent state)))))
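
;; With probability epsilon the policy chooses uniformly at random among
;; the actions; otherwise it acts greedily, breaking ties among maximal
;; actions at random.  (Since with-prob, defined at the end of this file,
;; is a function, both alternatives are computed on every call.)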

(defun arg-max-random-tiebreak (list)
  "Returns the index of the largest value in the non-null LIST, breaking
ties randomly; the largest value itself is returned as a second value."
  (loop with best-args = (list 0)
        with best-value = (first list)
        for i from 1
        for value in (rest list) 
        do (cond ((< value best-value))
                 ((> value best-value)
                  (setq best-value value)
                  (setq best-args (list i)))
                 ((= value best-value)
                  (push i best-args)))
        finally (return (values (nth (random (length best-args))
                                     best-args)
                                best-value))))
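
;; For example, (arg-max-random-tiebreak '(1 3 3 2)) returns 1 or 2 (each
;; half of the time) as its first value, and 3 as its second.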


;;; TABULAR ENVIRONMENTS

(defclass finite-MDP (environment)
  ((num-states :initarg :num-states :accessor num-states)
   (num-actions :initarg :num-actions :accessor num-actions)
   (state :initarg :initial-state :initform 0)))

(defmethod env-start-episode ((env finite-MDP))
  (with-slots (state) env
    (setq state 0)))

(defmethod env-step ((env finite-MDP) action)
  (with-slots (state) env
    (let ((old-state state))
      (values (setq state (env-next-state env state action))
              (env-next-reward env old-state state action)))))
   
;; This completes the lowest-level specification of a Markov decision task.
;; It is something which can return samples of the next state and reward.
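;;
;; As a quick illustration, a single transition can be sampled by hand.
;; This is just a sketch; it assumes RLI supplies the primary env-init
;; method that the :after method below extends:
;;
;;   (let ((env (make-instance 'maintenance-task)))
;;     (env-init env)
;;     (env-start-episode env)
;;     (env-step env 1))   ; => two values: the next state and the reward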


;; Next we have an example MDP: the maintenance task.  This is a continual
;; task, with no episodes or resets.  You are running a machine so as to
;; maximize reward.  The only way to earn reward is to operate the machine.
;; If that works, you earn $1 and proceed from your current state i to i+1.
;; But the machine might also break.  Then you go to a broken state (with
;; zero reward), where you stay with probability Q on each step until you
;; get out to state 0.  If you choose not to operate the machine, you can
;; instead do maintenance.  This earns no reward and takes you back to
;; state 0.  The probability that operating the machine will work is
;; P^(i+1), except in state N, where it never works (e.g., with the default
;; P = .9, operating succeeds with probability .9 from state 0, .81 from
;; state 1, and so on).  Counting the broken state, there are N+2 states.

(defclass maintenance-task (finite-MDP)
  ((N :initform 10 :initarg :N)
   (P :initform .9 :initarg :P)
   (Q :initform .9 :initarg :Q)
   (num-actions :initform 2)))
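
;; States 0 through N are the working states; state N+1 is the broken state.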

(defmethod env-init :after ((env maintenance-task))
  (with-slots (N num-states) env
    (setf num-states (+ N 2))))

(defmethod env-next-state ((env maintenance-task) x a)
  (with-slots (N P Q) env
    (cond ((= x (+ N 1))                  ; already broken
           (with-prob Q x 0))             ; with prob Q stay broken, else reset to 0
          ((= a 0)                        ; maintenance action
           0)                             ;  always causes a reset
          ((= x N)                        ; but in the final state
           (+ N 1))                       ;  operating must fail
          (t (with-prob (expt P (+ x 1))  ; otherwise take your chances
               (+ x 1)                    ;  to get to the next state
               (+ N 1))))))               ;  or break!
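
;; As a sanity check on these dynamics, the success probability of
;; operating can be estimated empirically.  The function below is a
;; hypothetical helper, not part of RLI: operating (action 1) from a
;; working state x < N should reach x+1 with probability P^(x+1).

(defun estimate-operate-success (env x &optional (trials 10000))
  "Fraction of TRIALS in which operating from state X reaches state X+1."
  (/ (loop repeat trials
           count (= (env-next-state env x 1) (+ x 1)))
     trials))

;; e.g., (estimate-operate-success (make-instance 'maintenance-task) 0)
;; should come out near .9 with the default P.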

(defmethod env-next-reward ((env maintenance-task) x y a)
  (declare (ignore x a))
  (with-slots (N) env
    (if (< 0 y (+ N 1))
      1
      0)))
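
;; Reward is 1 exactly when the machine lands in a working state 1..N,
;; which happens only on a successful operate step; maintenance, resets,
;; and breakdowns all earn 0.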


(defun with-prob (p x y &optional (random-state *random-state*))
  "Return X with probability P, otherwise Y.  Note that this is a function,
not a macro, so both X and Y are evaluated on every call."
  (if (< (random 1.0 random-state) p)
      x
      y))
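
;; e.g., (with-prob .9 :heads :tails) => :heads about 90% of the time.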