; Copyright (c) 2024-2026 Diogo Behrens
;
; Permission to use, copy, modify, and/or distribute this software for any
; purpose with or without fee is hereby granted.
;
; THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
; REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
; AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
; INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
; LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
; OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
; PERFORMANCE OF THIS SOFTWARE.

(import (scheme base)
        (scheme cxr)
        (scheme write)
        (srfi 1)
        (srfi 69))

;;; (print ITEM ...)
;;; Writes each ITEM, left to right, with `display` and then ends the line
;;; with `newline`. With no ITEM it only emits the newline.
(define-syntax print
  (syntax-rules ()
    ((_)
     (newline))
    ((_ item ...)
     (begin
       (display item) ...
       (newline)))))

;; -----------------------------------------------------------------------------
;; atomic block
;; -----------------------------------------------------------------------------

;; Parameter that is #t while the body of an `atomic-block` form is being
;; evaluated and #f otherwise.
(define *in-atomic-block* (make-parameter #f))

;; (atomic-block BODY ...)
;; Evaluates BODY ... with `*in-atomic-block*` bound to #t and returns the
;; value of the last BODY form.
(define-syntax atomic-block
  (syntax-rules ()
    ((_ body ...)
     (parameterize ((*in-atomic-block* #t))
       body ...))))

;; Returns the current value of the `*in-atomic-block*` parameter.
(define in-atomic-block?
  (lambda ()
    (*in-atomic-block*)))

;; ----------------------------------------------------------------------------
;; DSL
;; ----------------------------------------------------------------------------

;;; Cooperative scheduling point.
;;; Outside an atomic block, passes the argument list `md` to `yield`, which
;;; queues the current task together with `md` and resumes a task.
;;; Inside an atomic block this is a no-op.
;;; `md` is the metadata later inspected by the sleep-set code (`commute?`).
(define YIELD
  (lambda md
    (when (not (in-atomic-block?))
      (yield md))))

;;; Register a new task that will run `thunk`, a procedure of no arguments.
;;; Delegates to `spawn`.
(define SPAWN
  (lambda (thunk)
    (spawn thunk)))

;;; Explore the executions of the spawned tasks (subject to pruning).
;;; Every argument is a 0-argument predicate; the collected list `postconds`
;;; is handed to `explore`, which checks the predicates after each execution.
;;; Returns whatever `explore` returns.
(define EXPLORE
  (lambda postconds
    (explore postconds)))

;;; Give the shared memory location `loc` its initial value `val`.
;;; Delegates to `init-location`, which signals an error when `loc` has
;;; already been initialized.
(define INIT
  (lambda (loc val)
    (init-location loc val)))

;;; Store `val` in location `loc`.
;;; Steps, in order:
;;;   1. yield with metadata (WRITE loc) -- suppressed inside an atomic block;
;;;   2. log the access as "WRITE <loc> <val>";
;;;   3. store the value (`write-location` errors on an uninitialized `loc`);
;;;   4. wake up every task waiting on `loc` (`signal-location`).
(define (WRITE loc val)
  (YIELD 'WRITE loc)
  ;; Same output as (print "WRITE " loc " " val), written out explicitly.
  (display "WRITE ")
  (display loc)
  (display " ")
  (display val)
  (newline)
  (write-location loc val)
  (signal-location loc))

;;; Return the value currently stored in location `loc`.
;;; Yields first with metadata (READ loc) -- suppressed inside an atomic
;;; block -- then fetches the value (`read-location` errors on an
;;; uninitialized `loc`) and logs the access as "READ <loc> <value>".
(define (READ loc)
  (YIELD 'READ loc)
  (let ((observed (read-location loc)))
    (print "READ " loc " " observed)
    observed))

;;; (ATOMIC BODY ...)
;;; Runs BODY ... as one indivisible step with respect to other tasks:
;;; a single YIELD (without metadata) is performed up front, then the body is
;;; evaluated inside `atomic-block`, where further YIELDs are suppressed.
;;; The value of the last BODY form is returned.
(define-syntax ATOMIC
  (syntax-rules ()
    ((_ body ...)
     (begin
       (YIELD)
       (atomic-block body ...)))))

;;; Atomic read-modify-write of location `loc`.
;;; Yields once with metadata (RMW loc val) unless already atomic, then,
;;; inside an atomic block (so the inner READ/WRITE do not yield):
;;;   old := (READ loc); (WRITE loc (op old val)); return old
(define (RMW op loc val)
  (YIELD 'RMW loc val)
  (atomic-block
    (let* ((old (READ loc))
           (updated (op old val)))
      (WRITE loc updated)
      old)))

;;; Model a non-deterministic choice: returns one of the arguments.
;;; The selection is made by `choose`, which receives the arguments as the
;;; list `options`.
(define CHOOSE
  (lambda options
    (choose options)))

;;; Assertion helper for postconditions and task code.
;;; Signals an error with message "ASSERT" and irritant `msg` when `cnd` is
;;; #f; otherwise does nothing.
(define (ASSERT cnd msg)
  (cond
    ((not cnd) (error "ASSERT" msg))))

;;; Block the current task until `loc` is signaled (WRITE signals the location
;;; it writes to).
;;;
;;; Must not be called inside an atomic block: `await-location` suspends the
;;; current task and resumes others, which would let them run in the middle of
;;; the block. This precondition is checked and violated calls raise an error
;;; with `loc` as irritant.
(define (AWAIT loc)
  (when (in-atomic-block?)
    (error "AWAIT: cannot be called inside an atomic block" loc))
  (await-location loc))

;; -----------------------------------------------------------------------------
;; taskqueue
;; -----------------------------------------------------------------------------

;; List of runnable tasks. `push-task` conses new tasks onto the front.
(define *tasks* '())

;; Returns #t when no runnable task is queued.
(define no-tasks?
  (lambda ()
    (null? *tasks*)))

;; Add `task` to the front of the task queue.
;; Signals an error if the very same task object (`memq`) is already queued.
(define (push-task task)
  (if (memq task *tasks*)
      (error "push-task: duplicate continuation")
      (set! *tasks* (cons task *tasks*))))

;; Remove and return the task at position `idx` of the task queue.
;; Returns #f, leaving the queue untouched, when `idx` is #f or outside
;; the range [0, number of queued tasks).
(define (take-task idx)
  (and idx
       (not (negative? idx))
       (< idx (length *tasks*))
       (let ((chosen (list-ref *tasks* idx)))
         ;; `delete` (SRFI 1) drops every element `equal?` to the chosen task.
         (set! *tasks* (delete chosen *tasks*))
         chosen)))

;; Drop every queued task.
(define clear-taskqueue!
  (lambda ()
    (set! *tasks* '())))

;; Build a task record, represented as the list (cc tid . md):
;;   cc  - procedure invoked to run/resume the task,
;;   tid - integer task id,
;;   md  - list of metadata items (may be empty).
;; Signals a "malformed task" error when an argument has the wrong type.
(define (make-task cc tid md)
  (cond
    ((not (procedure? cc))
     (error "malformed task (expected continuation procedure)" cc))
    ((not (integer? tid))
     (error "malformed task (expected integer tid)" tid))
    ((not (list? md))
     (error "malformed task (expected metadata list)" md))
    (else
     (cons cc (cons tid md)))))

;; Return the continuation procedure stored in the first slot of `task`.
;; Signals an error when `task` is not a pair whose car is a procedure.
(define (task-cc task)
  (unless (and (pair? task)
               (procedure? (car task)))
    (error "malformed task (expected (cc tid ...))" task))
  (car task))

;; Return a snapshot of the scheduler state: the current task list itself
;; (no copy is made).
(define snapshot-taskqueue
  (lambda ()
    *tasks*))

;; Replace the task queue with `tasks`, a value previously obtained from
;; `snapshot-taskqueue`.
(define restore-taskqueue!
  (lambda (tasks)
    (set! *tasks* tasks)))

;; -----------------------------------------------------------------------------
;; task id
;; -----------------------------------------------------------------------------

;; Next task id to hand out. Allocation starts at 1; 0 is the default value
;; of `*current-tid*`, i.e. the id in effect outside any spawned task.
(define *next-tid* 1)

;; Restart task-id allocation from its initial value (1), so that a reset
;; really restores the start-up state and never hands out id 0.
(define (reset-tid)
  (set! *next-tid* 1))

;; Return a fresh task id and advance the counter.
(define (new-tid)
  (let ((tid *next-tid*))
    (set! *next-tid* (+ 1 tid))
    tid))

;; Parameter holding the id of the task being run. `resume` rebinds it to the
;; tid of the task it switches to; its default value is 0.
(define *current-tid* (make-parameter 0))

;; -----------------------------------------------------------------------------
;; scheduler
;; -----------------------------------------------------------------------------

;; Parameter holding the procedure that ends the current execution;
;; #f while no scheduler is running (`run` installs a finisher).
(define *finisher* (make-parameter #f))

;; Build a finisher around the continuation `cc`: the resulting one-argument
;; procedure ignores its argument, empties the task queue and then jumps to
;; `cc` with #f.
(define make-finisher
  (lambda (cc)
    (lambda (ignored)
      (clear-taskqueue!)
      (cc #f))))

;; Create a task for `thunk` and queue it under a freshly allocated tid.
;; When the task is eventually run it calls `thunk` and, once the thunk
;; returns, hands control back to the scheduler via `resume`, which is not
;; expected to return.
(define (spawn thunk)
  (let ((tid (new-tid)))
    (push-task
      (make-task (lambda (ignored)
                   (thunk)
                   (resume)
                   (error "unreachable"))
                 tid
                 '()))))

;; Suspend the current task and let the scheduler pick the next one.
;; The current continuation is queued as a task carrying the current tid and
;; the metadata list `md`; control then passes to `resume`.
;; Signals an error when no scheduler is running (`*finisher*` is #f).
(define (yield md)
  (unless (*finisher*)
    (error "scheduler not running"))
  (call-with-current-continuation
    (lambda (k)
      (push-task (make-task k (*current-tid*) md))
      (resume)
      (error "unreachable"))))

;; Switch to the next task.
;; The execution is finished (by calling the current finisher with #f) when
;; the queue is empty or when `pick-next-tasks` returns #f, i.e. the node is
;; pruned and treated as terminal for this execution.
;; Otherwise the selected task is removed from the queue and its continuation
;; is invoked with `*current-tid*` bound to the task's tid.
(define (resume)
  (let ((idx (and (not (no-tasks?))
                  (pick-next-tasks))))
    (cond
      ((not idx)
       ((*finisher*) #f))
      (else
       (let* ((task (take-task idx))
              (k (task-cc task))
              (tid (task-tid task)))
         (parameterize ((*current-tid* tid))
           (k #f)))))))

;; Run one execution: install a finisher that escapes back to this call and
;; start scheduling with `resume`. Returns when the finisher is invoked.
(define (run)
  (call-with-current-continuation
    (lambda (return)
      (parameterize ((*finisher* (make-finisher return)))
        (resume)))))

;; -----------------------------------------------------------------------------
;; history
;; -----------------------------------------------------------------------------

;; History of the selections made so far in the current execution.
;; Each entry has the form (idx . max-idx); `add-history` conses new entries
;; onto the front, so the most recent selection comes first.
(define *history* '())

;; Returns #t if at least one scheduling point in `hist` still has an untried
;; alternative, i.e. some entry (idx . max-idx) has idx < max-idx;
;; returns #f otherwise (including for an empty history).
(define (has-future? hist)
  (let scan ((rest hist))
    (and (not (null? rest))
         (or (< (caar rest) (cdar rest))
             (scan (cdr rest))))))

;; Record a selection: push the entry (idx . max-idx) onto the front of
;; `*history*`.
(define (add-history idx max-idx)
  (set! *history*
        (cons (cons idx max-idx) *history*)))

;; Return the recorded selections as a fresh list in the order they were
;; made (oldest first); `*history*` itself is left unchanged.
(define extract-history
  (lambda ()
    (reverse *history*)))

;; Return the recorded selections in the order they were made (oldest first)
;; and reset `*history*` to the empty list.
(define (swap-history!)
  (let ((chronological (reverse *history*)))
    (set! *history* '())
    chronological))

;; -----------------------------------------------------------------------------
;; safety checks
;; -----------------------------------------------------------------------------

;; Check the postconditions of an execution.
;; `postconds` is a list of 0-argument predicates; they are called in reverse
;; list order and the first one returning #f raises "postcondition failed".
;; While the predicates run, `*finisher*` is bound to #t and
;; `*in-atomic-block*` to #t, so YIELD calls made by a predicate do not
;; suspend it.
(define (run-postconds postconds)
  (parameterize ((*finisher* #t)
                 (*in-atomic-block* #t))
    (let check ((pending (reverse postconds)))
      (unless (null? pending)
        (unless ((car pending))
          (error "postcondition failed"))
        (check (cdr pending))))))

;; -----------------------------------------------------------------------------
;; sleepsets
;; -----------------------------------------------------------------------------

;; When true, `commute?` reports that no two tasks commute, so sleep sets
;; stay empty and nothing is pruned.
(define *sleepset-disabled* #f)

;; Current sleep set: a list of tids that must not be scheduled at the current
;; node. It starts out as the empty list (rather than #f) because the
;; scheduling code passes it straight to `memv`, which expects a list.
(define *sleepset* '())

;; Empty the sleep set (done at the start of every execution).
(define (reset-sleepset!)
  (set! *sleepset* '()))

;; Return the integer tid stored in the second slot of `task`.
;; Signals an error when `task` does not have the shape (cc tid ...) with an
;; integer tid.
(define (task-tid task)
  (unless (and (pair? task)
               (pair? (cdr task))
               (integer? (cadr task)))
    (error "malformed task (expected (cc tid ...))" task))
  (cadr task))
;; Accessors for a task of the form (cc tid op loc ...).
;;
;; NOTE: `task-md` returns everything after the continuation, i.e. the list
;; (tid op loc ...) -- the tid followed by the metadata, not the metadata
;; alone. `memory-access?` expects exactly this shape.
(define (task-md t)
  (cdr t))

;; First metadata item of the task (the operation symbol for memory accesses).
(define (task-op t)
  (list-ref t 2))

;; Second metadata item of the task (the location for memory accesses).
(define (task-loc t)
  (list-ref t 3))

;; Sanity check: every task in `tasks` must carry a distinct tid.
;; Returns #t when that holds; otherwise signals an error naming the first
;; repeated tid.
(define (assert-unique-task-tids tasks)
  (let check ((pending tasks)
              (seen-tids '()))
    (if (null? pending)
        #t
        (let ((tid (task-tid (car pending))))
          (if (memv tid seen-tids)
              (error "duplicate tid in taskqueue" tid)
              (check (cdr pending) (cons tid seen-tids)))))))

;; Tests whether `t` -- a list of the form (tid op ...), as returned by
;; `task-md` -- describes a memory access, i.e. its second element is one of
;; READ, WRITE or RMW.
;; Returns a true value (the result of `memq`) for a memory access, else #f.
(define (memory-access? t)
  (if (and (pair? t)
           (pair? (cdr t)))
      (memq (cadr t) '(READ WRITE RMW))
      #f))

;; Decide whether the pending operations of tasks `t1` and `t2` commute
;; (are independent). They commute only if all of the following hold:
;;   - both tasks are suspended at a memory access (READ/WRITE/RMW),
;;   - sleep sets are not disabled (`*sleepset-disabled*` is #f),
;;   - they touch different locations, or both operations are READs.
(define (commute? t1 t2)
  (and (memory-access? (task-md t1))
       (memory-access? (task-md t2))
       (not *sleepset-disabled*)
       (or (not (equal? (task-loc t1) (task-loc t2)))
           (and (eq? (task-op t1) 'READ)
                (eq? (task-op t2) 'READ)))))

;; Compute the sleep set of the child node reached by running task number `i`
;; of the enabled-task list `E`, given the current sleep set `Z` (a list of
;; tids). The result is the duplicate-free list of tids of those tasks that
;;   - are currently asleep (tid in `Z`) or precede position `i` in `E`, and
;;   - commute with the selected task.
(define (sleepset-child E Z i)
  (define chosen (list-ref E i))
  (define (asleep? task)
    (memv (task-tid task) Z))
  (define (independent? other)
    (commute? chosen other))
  (delete-duplicates
    (map task-tid
         (filter independent?
                 (append (filter asleep? E)
                         (take E i))))
    =))

;; Starting at position `i`, find the first task in the enabled-task list `E`
;; whose tid is not in the sleep set `Z`.
;; Returns its index, or #f when every task from `i` on is asleep (or `i` is
;; past the end of `E`).
(define (next-nonsleep-index E Z i)
  (let ((count (length E)))
    (let scan ((k i))
      (cond
        ((>= k count) #f)
        ((memv (task-tid (list-ref E k)) Z) (scan (+ k 1)))
        (else k)))))

;; -----------------------------------------------------------------------------
;; systematic exploration
;; -----------------------------------------------------------------------------

;; Depth-first selection of the next index, driven by the history `hist` of
;; the previous execution (oldest entry first, entries are (idx . max-idx)).
;; Returns two values: the index to take now and the history still to replay.
;;   - empty history: take index 0, nothing left to replay;
;;   - a later scheduling point still has an untried option: repeat the
;;     recorded index and keep replaying the rest;
;;   - otherwise, if this point has an untried option: take the next index and
;;     stop replaying;
;;   - otherwise the history is exhausted, which callers must never pass in.
(define (pick/dfs hist)
  (if (null? hist)
      (values 0 '())
      (let ((idx (caar hist))
            (max-idx (cdar hist))
            (later (cdr hist)))
        (cond
          ((has-future? later) (values idx later))
          ((< idx max-idx) (values (+ idx 1) '()))
          (else (error "should never happen"))))))

;; Part of the previous execution's history that still has to be replayed.
(define *prev-hist* '())

;; Select an index into `lst` according to the depth-first strategy.
;; Consumes one entry of `*prev-hist*`, records the selection together with
;; the largest valid index of `lst` in the history, and returns the index.
(define (pick-next lst)
  (call-with-values
    (lambda () (pick/dfs *prev-hist*))
    (lambda (choice remaining)
      (set! *prev-hist* remaining)
      (add-history choice (- (length lst) 1))
      choice)))

;; Select the index of the next task to run, honoring the sleep set.
;; The depth-first strategy proposes a start index; tasks whose tid is in the
;; sleep set are skipped from there on.
;; Returns the selected index, after updating `*sleepset*` to the child's
;; sleep set and recording the selection in the history.
;; Returns #f, without recording anything, when no non-sleeping task is left:
;; the node is pruned.
(define (pick-next-tasks)
  (let* ((enabled *tasks*)
         (last-idx (- (length enabled) 1))
         (sleeping *sleepset*))
    (assert-unique-task-tids enabled)
    (let-values (((start remaining) (pick/dfs *prev-hist*)))
      (set! *prev-hist* remaining)
      (let ((k (next-nonsleep-index enabled sleeping start)))
        (cond
          ((not k) #f)
          (else
           (set! *sleepset* (sleepset-child enabled sleeping k))
           (add-history k last-idx)
           k))))))

;; Systematically run the executions reachable from the current state
;; (subject to sleep-set pruning) and check `postconds`, a list of 0-argument
;; predicates, after each one.
;;
;; The task queue, memory and monitor are snapshotted on entry and restored
;; after every execution, so each execution starts from the same state.
;; The loop ends once the history of the last execution has no untried
;; alternative left (`has-future?`).
;; Returns the number of executions performed.
(define (explore postconds)
  (let ((initial-taskqueue (snapshot-taskqueue))
        (initial-memory (snapshot-memory))
        (initial-monitor (snapshot-monitor)))
    ;; Start from an empty history. A previous exploration leaves its final,
    ;; fully exhausted history behind; replaying it in the first execution
    ;; would make `pick/dfs` fail with "should never happen".
    (set! *history* '())
    (let loop ((execution-count 0))
      (if (or (zero? execution-count)
              (has-future? (extract-history)))
          ;; The history of the previous execution becomes the replay script
          ;; of this one; a fresh history is recorded while it runs.
          (let ((prev-hist (swap-history!)))
            (print "\n=== EXECUTION " (+ 1 execution-count) " ===")
            (reset-sleepset!)
            (set! *prev-hist* prev-hist)
            (run)
            (print "== HISTORY: " (extract-history))
            (run-postconds postconds)
            (restore-taskqueue! initial-taskqueue)
            (restore-memory! initial-memory)
            (restore-monitor! initial-monitor)
            (loop (+ execution-count 1)))
          execution-count))))

;; Return one element of the list `options`; the element is selected by
;; `pick-next`, so the choice is recorded in the exploration history.
;; Signals an error when no scheduler is running or `options` is empty.
(define (choose options)
  (cond
    ((not (*finisher*))
     (error "CHOOSE: scheduler not running"))
    ((null? options)
     (error "CHOOSE: empty option list"))
    (else
     (list-ref options (pick-next options)))))

;; -----------------------------------------------------------------------------
;; memory
;;
;; Memory represents the system's state: it is a map from locations to values.
;; -----------------------------------------------------------------------------

;; Extra argument passed to `make-hash-table` for the memory table.
;; NOTE(review): SRFI 69 leaves arguments after the hash function
;; implementation-dependent; presumably this is an initial size hint.
(define hash-table-buckets 1024)
;; The shared memory: maps location keys (strings produced by
;; `location->string`) to values.
(define *memory* (make-hash-table equal? string-hash hash-table-buckets))

;; Create location `loc` in memory with initial value `val`.
;; Signals an error (with the location's string key and `val` as irritants)
;; when the location already exists.
(define (init-location loc val)
  (let ((key (location->string loc)))
    (when (hash-table-exists? *memory* key)
      (error "reinitialized location" key val))
    (hash-table-set! *memory* key val)))

;; Overwrite the value of the existing location `loc` with `val`.
;; Signals an error (with the location's string key and `val` as irritants)
;; when the location has not been initialized.
(define (write-location loc val)
  (let ((key (location->string loc)))
    (unless (hash-table-exists? *memory* key)
      (error "write to uninitialized location" key val))
    (hash-table-set! *memory* key val)))

;; Return the value stored at location `loc`.
;; Signals an error (with the location's string key as irritant) when the
;; location has not been initialized.
(define (read-location loc)
  (let ((key (location->string loc)))
    (hash-table-ref *memory*
                    key
                    ;; Called only when `key` is absent from the table.
                    (lambda ()
                      (error "read from uninitialized location" key)))))

;; Return a snapshot of memory: a copy of the `*memory*` hash table.
(define snapshot-memory
  (lambda ()
    (hash-table-copy *memory*)))

;; Replace memory with a copy of the snapshot `mem` (the snapshot itself is
;; not shared, so it can be restored again later).
;; Signals an error when `mem` is not a hash table.
(define (restore-memory! mem)
  (unless (hash-table? mem)
    (error "could not restore *memory*" mem))
  (set! *memory* (hash-table-copy mem)))

;; Convert a location designator into the string used as memory key.
;;   string -> itself
;;   number -> its `number->string` representation
;;   symbol -> its name
;;   pair   -> "<car>:<cdr>", converting both halves recursively
;; Any other value (including the empty list, hence also proper lists, whose
;; last cdr is the empty list) is rejected with an error that carries the
;; offending value as irritant.
;; Note that distinct designators can map to the same key, e.g. 1 and "1".
(define (location->string loc)
  (cond ((string? loc) loc)
        ((number? loc) (number->string loc))
        ((symbol? loc) (symbol->string loc))
        ((pair? loc)
         (string-append
           (location->string (car loc)) ":"
           (location->string (cdr loc))))
        (else (error "cannot convert location to string" loc))))

;; -----------------------------------------------------------------------------
;; monitor
;; -----------------------------------------------------------------------------

;; Maps a location key (string) to the list of tasks waiting on that location.
(define *loc-waiters* (make-hash-table equal? string-hash 1024))

;; Suspend the current task until location `loc` is signaled.
;; The current continuation is stored as a waiter for `loc` (it is not put on
;; the run queue) and control passes to `resume`.
;; Signals "AWAIT: termination violation" when no runnable task is left to
;; switch to.
(define (await-location loc)
  (let* ((key (location->string loc))
         (queued (hash-table-ref/default *loc-waiters* key '())))
    (call-with-current-continuation
      (lambda (k)
        (hash-table-set! *loc-waiters*
                         key
                         (cons (make-task k (*current-tid*) '()) queued))
        (when (no-tasks?)
          (error "AWAIT: termination violation"))
        (resume)))))
;; Wake up every task waiting on location `loc`: each waiter is moved to the
;; run queue with `push-task`, then the waiter entry for `loc` is removed.
(define (signal-location loc)
  (let ((key (location->string loc)))
    (let wake ((pending (hash-table-ref/default *loc-waiters* key '())))
      (unless (null? pending)
        (push-task (car pending))
        (wake (cdr pending))))
    (hash-table-delete! *loc-waiters* key)))

;; Return a snapshot of the monitor: a copy of the `*loc-waiters*` table.
(define snapshot-monitor
  (lambda ()
    (hash-table-copy *loc-waiters*)))
;; Replace the waiter table with a copy of the snapshot `x` (the snapshot
;; itself is not shared, so it can be restored again later).
;; Signals an error when `x` is not a hash table, mirroring `restore-memory!`.
(define (restore-monitor! x)
  (if (not (hash-table? x))
      (error "could not restore *loc-waiters*" x)
      (set! *loc-waiters* (hash-table-copy x))))
