; Copyright (c) 2024-2026 Diogo Behrens
;
; Permission to use, copy, modify, and/or distribute this software for any
; purpose with or without fee is hereby granted.
;
; THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
; REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
; AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
; INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
; LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
; OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
; PERFORMANCE OF THIS SOFTWARE.

(import (scheme base)
        (scheme write)
        (srfi 1)
        (srfi 69))

;; Display each argument in order (using `display`), then a newline.
;; Example: (print "x = " 42) writes "x = 42" and a newline to the current
;; output port. With no arguments only the newline is written.
(define-syntax print
  (syntax-rules ()
    ((_ item ...)
     (begin (display item) ...
            (newline)))))

;; -----------------------------------------------------------------------------
;; atomic block
;; -----------------------------------------------------------------------------

;; Parameter that is #t while code runs inside an atomic block, #f otherwise.
;; YIELD consults it and does not suspend the current task while it is #t.
(define *in-atomic-block* (make-parameter #f))

;; Evaluate the body forms with *in-atomic-block* bound to #t. The previous
;; binding is restored when the body is exited. Returns the value of the
;; last form.
(define-syntax atomic-block
  (syntax-rules ()
    ((atomic-block form ...)
     (parameterize ((*in-atomic-block* #t))
       form ...))))

;; Return the current value of the *in-atomic-block* parameter.
(define in-atomic-block?
  (lambda ()
    (*in-atomic-block*)))

;; ----------------------------------------------------------------------------
;; DSL
;; ----------------------------------------------------------------------------

;;; Cooperative scheduling point.
;;; Outside an atomic block, this suspends the current task and hands control
;;; to the scheduler, which resumes some queued task (possibly this one).
;;; Inside an atomic block it does nothing.
(define (YIELD)
  (if (not (in-atomic-block?))
      (yield)))

;;; Register a new task that calls `thunk` (a procedure of no arguments)
;;; when the scheduler selects it.
(define SPAWN
  (lambda (thunk)
    (spawn thunk)))

;;; Systematically run every execution (depth-first over task interleavings
;;; and CHOOSE decisions).
;;; Each argument is a 0-argument predicate evaluated after every execution;
;;; a predicate returning #f raises an error.
;;; Returns the number of executions performed.
(define EXPLORE
  (lambda postconds
    (explore postconds)))

;;; Give shared location `loc` its initial value `val`.
;;; Intended for setup before SPAWN/EXPLORE. Initializing the same location
;;; twice raises an error. Unlike READ and WRITE, this is not a yield point.
(define INIT
  (lambda (loc val)
    (init-location loc val)))

;;; Store value `val` into shared location `loc`.
;;; Yields first (unless inside an atomic block), so other tasks may run
;;; before the store happens. `loc` must already have been initialized with
;;; INIT; otherwise an error is raised.
;;; The "WRITE" trace line is emitted only after the store succeeded (as READ
;;; does), so the log never reports a write that did not take place.
(define (WRITE loc val)
  (YIELD)
  (write-location loc val)
  (print "WRITE " loc " " val))

;;; Fetch and return the value stored in shared location `loc`.
;;; Yields first (unless inside an atomic block). Raises an error if `loc`
;;; was never initialized. Emits a "READ" trace line with the value read.
(define (READ loc)
  (YIELD)
  (let ((current (read-location loc)))
    (print "READ " loc " " current)
    current))

;;; Run the body as one indivisible step with respect to other tasks.
;;; A single yield point precedes the body. Within the body YIELD (and hence
;;; READ/WRITE) does not switch tasks. Returns the value of the last form.
(define-syntax ATOMIC
  (syntax-rules ()
    ((ATOMIC form ...)
     (begin (YIELD)
            (atomic-block form ...)))))

;;; Nondeterministically return one of the arguments; across executions the
;;; explorer tries each option in turn. Must be called while the scheduler
;;; is running, and at least one option is required.
(define CHOOSE
  (lambda options
    (choose options)))

;;; Raise an error (message "ASSERT", irritant `msg`) when `cnd` is #f.
;;; Usable both in task bodies and in postconditions.
(define (ASSERT cnd msg)
  (if (not cnd)
      (error "ASSERT" msg)))

;; -----------------------------------------------------------------------------
;; taskqueue
;; -----------------------------------------------------------------------------

;; Pending tasks. Each task is a procedure of one argument: either a spawned
;; thunk wrapped by `spawn`, or a continuation captured by `yield`. New
;; tasks are consed onto the front.
(define *tasks* (list))

;; #t when no task is pending.
(define (no-tasks?)
  (eq? *tasks* '()))

;; Add `task` to the front of the queue.
;; Pushing a task that is already queued (compared with eq?) is an error.
(define (push-task task)
  (if (memq task *tasks*)
      (error "push-task: duplicate continuation")
      (set! *tasks* (cons task *tasks*))))

;; Remove and return the task at position `idx` of the queue.
;; Returns #f, leaving the queue untouched, when `idx` is #f or outside
;; [0, length). Removal uses SRFI-1 `delete`, which drops every equal? copy;
;; push-task rejects duplicates, so normally exactly one entry goes away.
(define (take-task idx)
  (if (and idx
           (not (negative? idx))
           (< idx (length *tasks*)))
      (let ((chosen (list-ref *tasks* idx)))
        (set! *tasks* (delete chosen *tasks*))
        chosen)
      #f))

;; Drop every pending task.
(define (clear-taskqueue!)
  (set! *tasks* (list)))

;; Capture the current task queue. The list itself is returned, not a copy.
;; push-task, take-task and clear-taskqueue! rebuild the list instead of
;; mutating it, so the snapshot is not disturbed by later queue operations.
(define snapshot-taskqueue
  (lambda ()
    *tasks*))

;; Reinstall a queue previously obtained from snapshot-taskqueue.
(define restore-taskqueue!
  (lambda (tasks)
    (set! *tasks* tasks)))

;; -----------------------------------------------------------------------------
;; scheduler
;; -----------------------------------------------------------------------------

;; Parameter holding the procedure that ends the current execution.
;; It is #f when no execution is running; run-postconds binds it to #t.
(define *finisher* (make-parameter #f))

;; Build a finisher: a one-argument procedure (argument ignored) that
;; discards all pending tasks and escapes through continuation `cc` with #f.
(define make-finisher
  (lambda (cc)
    (lambda (ignored)
      (clear-taskqueue!)
      (cc #f))))

;; Queue a new task that runs `thunk`. When the thunk returns, the task
;; hands control back to the scheduler via resume. resume is not expected
;; to return, so the trailing error only guards against that.
(define (spawn thunk)
  (let ((task (lambda (ignored)
                (thunk)
                (resume)
                (error "unreachable"))))
    (push-task task)))

;; Suspend the current task: its continuation is queued as a task, and the
;; scheduler resumes some queued task (possibly this very one).
;; Returns when that continuation is later invoked.
;; Raises an error when no scheduler is running (*finisher* is #f).
(define (yield)
  (unless (*finisher*)
    (error "scheduler not running"))
  (call-with-current-continuation
    (lambda (k)
      (push-task k)
      (resume)
      (error "unreachable"))))

;; Transfer control to the next task chosen by pick-next. When no task is
;; pending, transfer control to the current finisher instead, which ends the
;; execution. Not expected to return normally.
;;
;; pick-next replays the previous execution's history. If that replay ever
;; selects an index outside the current queue, take-task yields #f. Report
;; this explicitly instead of failing later with an opaque "apply
;; non-procedure" error.
(define (resume)
  (if (no-tasks?)
      ((*finisher*) #f)
      (let* ((idx (pick-next *tasks*))
             (cc (take-task idx)))
        (unless cc
          (error "resume: selected task index out of range" idx))
        (cc #f))))

;; Perform one execution: bind *finisher* to a procedure that escapes back
;; to this call, then start scheduling. Returns #f once the finisher is
;; invoked, which happens when no task remains.
(define (run)
  (call-with-current-continuation
    (lambda (return)
      (let ((finisher (make-finisher return)))
        (parameterize ((*finisher* finisher))
          (resume))))))

;; -----------------------------------------------------------------------------
;; history
;; -----------------------------------------------------------------------------

;; Choices made during the current execution, most recent first.
;; Each entry is a pair (idx . max-idx): the index chosen at a scheduling
;; point and the largest index that was available there.
(define *history* '())

;; Return #t if some entry of `hist` still has an untried option
;; (idx < max-idx), i.e. another execution remains to be explored;
;; otherwise return #f.
(define (has-future? hist)
  (let scan ((rest hist))
    (and (not (null? rest))
         (or (< (caar rest) (cdar rest))
             (scan (cdr rest))))))

;; Record that index `idx` was chosen out of 0..max-idx at the current
;; scheduling point.
(define (add-history idx max-idx)
  (set! *history* (cons (cons idx max-idx) *history*)))

;; Return the current history in chronological order (oldest choice first).
(define (extract-history)
  (reverse *history*))

;; Empty the history and return its previous contents in chronological order.
(define (swap-history!)
  (let ((previous (reverse *history*)))
    (set! *history* '())
    previous))

;; -----------------------------------------------------------------------------
;; safety checks
;; -----------------------------------------------------------------------------

;; Check the postconditions after an execution. The predicates run in the
;; reverse of the order they appear in `postconds`. During the checks
;; *finisher* is bound to #t (presumably so scheduler-running checks pass)
;; and *in-atomic-block* to #t, so READ/WRITE inside a predicate never switch
;; tasks. The first predicate returning #f raises "postcondition failed".
(define (run-postconds postconds)
  (parameterize ((*finisher* #t)
                 (*in-atomic-block* #t))
    (let check ((remaining (reverse postconds)))
      (if (pair? remaining)
          (begin
            (if (not ((car remaining)))
                (error "postcondition failed"))
            (check (cdr remaining)))))))

;; -----------------------------------------------------------------------------
;; systematic exploration
;; -----------------------------------------------------------------------------

;; Depth-first selection of the next index while replaying `hist`, the
;; previous execution's choices (oldest first, each (idx . max-idx)).
;; Returns two values: the index to take now, and the history to consult at
;; the next scheduling point.
;;  - empty history: take 0, with nothing left to replay;
;;  - some later entry still has an untried option: replay this entry's idx;
;;  - otherwise, if this entry has an untried option: take idx+1 and stop
;;    replaying, so every later point starts from 0;
;;  - otherwise nothing is left to try; callers must not reach this case.
(define (pick/dfs hist)
  (if (null? hist)
      (values 0 '())
      (let ((entry (car hist))
            (later (cdr hist)))
        (cond ((has-future? later)
               (values (car entry) later))
              ((< (car entry) (cdr entry))
               (values (+ (car entry) 1) '()))
              (else
               (error "should never happen"))))))

;; Replay state: the not-yet-consumed part of the previous execution's
;; history. pick-next consumes it one scheduling point at a time.
(define *prev-hist* '())

;; Choose an index into `lst` (the queued tasks or the CHOOSE options) with
;; pick/dfs. Record the choice together with the largest valid index in the
;; history, and return the chosen index.
(define (pick-next lst)
  (call-with-values
    (lambda () (pick/dfs *prev-hist*))
    (lambda (choice remaining)
      (set! *prev-hist* remaining)
      (add-history choice (- (length lst) 1))
      choice)))

;; Run every execution of the system, depth-first over scheduling and CHOOSE
;; decisions. The task queue and memory present at the call are snapshotted
;; and restored after each execution. After every execution the history is
;; printed and `postconds` are checked with run-postconds.
;; Returns the number of executions performed.
;;
;; *history* is cleared up front. Otherwise the first execution would replay
;; a history left over from an earlier EXPLORE: one that was fully explored
;; would make pick/dfs raise "should never happen" at the first scheduling
;; point, and one abandoned by an error would be replayed as if it belonged
;; to this exploration.
(define (explore postconds)
  (let ((initial-taskqueue (snapshot-taskqueue))
        (initial-memory (snapshot-memory)))
    (set! *history* '())
    (let loop ((execution-count 0))
      (if (or (zero? execution-count)
              (has-future? (extract-history)))
          (let ((prev-hist (swap-history!)))
            (print "\n=== EXECUTION " (+ 1 execution-count) " ===")
            (set! *prev-hist* prev-hist)
            (run)
            (print "== HISTORY: " (extract-history))
            (run-postconds postconds)
            (restore-taskqueue! initial-taskqueue)
            (restore-memory! initial-memory)
            (loop (+ execution-count 1)))
          execution-count))))

;; Return one element of `options`, selected through pick-next so that
;; exploration tries each of them across executions.
;; Raises an error when no scheduler is running or `options` is empty.
(define (choose options)
  (cond ((not (*finisher*))
         (error "CHOOSE: scheduler not running"))
        ((null? options)
         (error "CHOOSE: empty option list"))
        (else
         (list-ref options (pick-next options)))))

;; -----------------------------------------------------------------------------
;; memory
;;
;; Memory represents a system's state, it is a map of locations to values.
;; -----------------------------------------------------------------------------

;; Extra argument passed to make-hash-table when creating *memory*.
;; SRFI 69 leaves its meaning implementation-specific (presumably an initial
;; size hint).
(define hash-table-buckets 1024)

;; Shared memory: maps location keys (strings produced by location->string)
;; to their current values.
(define *memory* (make-hash-table equal? string-hash hash-table-buckets))

;; Create location `loc` holding `val`. Raises an error if it already exists.
(define (init-location loc val)
  (let ((key (location->string loc)))
    (when (hash-table-exists? *memory* key)
      (error "reinitialized location" key val))
    (hash-table-set! *memory* key val)))

;; Overwrite the value at location `loc` with `val`.
;; Raises an error if the location was never initialized.
(define (write-location loc val)
  (let ((key (location->string loc)))
    (unless (hash-table-exists? *memory* key)
      (error "write to uninitialized location" key val))
    (hash-table-set! *memory* key val)))

;; Return the value stored at location `loc`.
;; Raises an error if the location was never initialized; the error is
;; raised from hash-table-ref's failure thunk.
(define (read-location loc)
  (let ((key (location->string loc)))
    (hash-table-ref *memory* key
                    (lambda ()
                      (error "read from uninitialized location" key)))))

;; Return a copy of the current memory table.
(define snapshot-memory
  (lambda ()
    (hash-table-copy *memory*)))

;; Replace the current memory with a copy of `mem` (e.g. a snapshot from
;; snapshot-memory). Copying keeps `mem` intact, so it can be restored again.
;; Raises an error if `mem` is not a hash table.
(define (restore-memory! mem)
  (unless (hash-table? mem)
    (error "could not restore *memory*" mem))
  (set! *memory* (hash-table-copy mem)))

;; Convert a location designator into the string key used by *memory*.
;; Accepted forms:
;;  - strings (used as-is), numbers and symbols (printed);
;;  - pairs, converted as "<car>:<cdr>" recursively, e.g. (cons 'arr 3) -> "arr:3";
;;  - proper lists, whose trailing '() terminates the key: '(arr 3) -> "arr:3".
;; Distinct designators with the same printed form share one key; for
;; example "x" and 'x are the same location, as are (cons 'a 'b) and '(a b).
;; Any other value raises an error with the offending location as irritant.
(define (location->string loc)
  (cond ((string? loc) loc)
        ((number? loc) (number->string loc))
        ((symbol? loc) (symbol->string loc))
        ;; last element of a proper list: do not try to convert the '() tail
        ((and (pair? loc) (null? (cdr loc)))
         (location->string (car loc)))
        ((pair? loc)
         (string-append
           (location->string (car loc)) ":"
           (location->string (cdr loc))))
        (else (error "cannot convert location to string" loc))))
