-- Appendix: Can I pet your DAG?

{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}

module Scratch.Scheduling where

import Control.Monad (void)
import Data.Array qualified as Array
import Data.Bifunctor (Bifunctor (..))
import Data.Coerce (coerce)
import Data.Graph qualified as Graph
import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Data.Hashable (Hashable)
import Data.List (sort)
import Data.Map qualified as Map
import Data.Set qualified as Set
import GHC.Generics (Generic)

-- | Every errand in the example. Constructor order matters: the derived
-- 'Enum' / 'Bounded' instances are used to number graph vertices.
data Task'
  = GoOutside
  | GoToSchool
  | TakePhysics
  | TakeExam
  | GoHome
  | GetGroceries
  | UnpackGroceries
  deriving stock (Eq, Ord, Bounded, Enum, Show, Generic)
  deriving anyclass (Hashable)

-- | Direct prerequisites of each task: a task maps to the tasks that must be
-- done before it. Every 'Task'' currently has an entry; tasks with no
-- prerequisites map to @[]@.
--
-- The list literal already has type @[(Task', [Task'])]@, so no coercion is
-- needed before handing it to 'Map.fromList'.
happensAfter' :: Map.Map Task' [Task']
happensAfter' =
  Map.fromList
    [ (GoOutside, [])
    , (GoToSchool, [GoOutside])
    , (TakePhysics, [GoToSchool])
    , (TakeExam, [TakePhysics])
    , (GetGroceries, [GoOutside])
    , (GoHome, [TakeExam, GetGroceries])
    , (UnpackGroceries, [GoHome])
    ]

-- | Inverse of 'happensAfter'': each task maps to the tasks that list it as a
-- direct prerequisite. Tasks that nothing depends on (e.g. 'UnpackGroceries')
-- have no key at all.
happensBefore' :: Map.Map Task' [Task']
happensBefore' = Map.fromListWith (<>) $ do
  (task, prereqs) <- Map.toList happensAfter'
  prereq <- prereqs
  pure (prereq, [task])

-- | @t1 \`dependsOn\` t2@ holds when @t2@ is a direct or transitive
-- prerequisite of @t1@, following 'happensAfter''. Tasks missing from the
-- map depend on nothing.
dependsOn :: Task' -> Task' -> Bool
dependsOn t1 t2 = maybe False viaPrereqs (Map.lookup t1 happensAfter')
 where
  -- Check the direct prerequisites first, then recurse through each of them.
  viaPrereqs prereqs = t2 `elem` prereqs || any (`dependsOn` t2) prereqs

-- | The dependency relation as a "Data.Graph" adjacency array. Vertex
-- @fromEnum t@ has an edge to @fromEnum p@ for every direct prerequisite @p@
-- of @t@ in 'happensAfter''.
--
-- Every 'Task'' from 'minBound' to 'maxBound' gets a slot. A task absent from
-- 'happensAfter'' is given no edges rather than an undefined array element,
-- which 'Array.array' would leave behind and which crashes when the graph is
-- traversed.
dependencyDAG :: Array.Array Int [Int]
dependencyDAG =
  Array.listArray
    (fromEnum lowest, fromEnum highest)
    [ fromEnum <$> Map.findWithDefault [] task happensAfter'
    | task <- [lowest .. highest]
    ]
 where
  lowest = minBound :: Task'
  highest = maxBound :: Task'

-- | Every task, prerequisites first. 'dependencyDAG' points from a task to
-- its prerequisites, so its reverse topological order lists each
-- prerequisite before the tasks that need it.
topologicSortedTasks :: [Task']
topologicSortedTasks = map toEnum (Graph.reverseTopSort dependencyDAG)

-- | Position of each task in 'topologicSortedTasks', starting at 0. Comparing
-- these ranks gives a total order in which every prerequisite precedes the
-- tasks that depend on it.
totallyOrderedTasks :: Map.Map Task' Word
totallyOrderedTasks =
  Map.fromList
    [ (task, rank)
    | (rank, task) <- zip [0 ..] topologicSortedTasks
    ]

-- | A 'Task'' with a different 'Ord' instance: tasks are ordered by their
-- dependency rank instead of by constructor order. All other instances are
-- reused unchanged from 'Task''.
newtype Task = Task Task'
  deriving newtype (Bounded, Enum, Eq, Hashable, Show)

-- | Order tasks by their rank in 'totallyOrderedTasks'. A task therefore
-- compares less than every task that directly or transitively depends on it.
instance Ord Task where
  compare (Task a) (Task b) = rank a `compare` rank b
   where
    rank task = totallyOrderedTasks Map.! task

-- | 'happensAfter'' re-keyed on 'Task', with the prerequisite list of each
-- task turned into a hash set.
happensAfter :: HashMap.HashMap Task (HashSet.HashSet Task)
happensAfter =
  HashMap.fromList
    [ (Task task, HashSet.fromList (map Task prereqs))
    | (task, prereqs) <- Map.toList happensAfter'
    ]

-- | 'happensBefore'' re-keyed on 'Task', with the dependent list of each task
-- turned into a hash set. As in 'happensBefore'', tasks nothing depends on
-- have no key.
happensBefore :: HashMap.HashMap Task (HashSet.HashSet Task)
happensBefore =
  HashMap.fromList
    [ (Task task, HashSet.fromList (map Task dependents))
    | (task, dependents) <- Map.toList happensBefore'
    ]

-- | Who carries out an action. Derived 'Ord' puts 'Billy' before 'Bobby'.
data Actor = Billy | Bobby
  deriving stock (Eq, Ord, Show)

-- | A task paired with the actor who performs it.
type Action task = (task, Actor)

-- | Billy's errands, each tagged with 'Billy' as the actor.
billyTodo :: [Action Task]
billyTodo =
  [ (Task errand, Billy)
  | errand <- [GoOutside, GoToSchool, TakePhysics, TakeExam, GoHome]
  ]

-- | Bobby's errands, each tagged with 'Bobby' as the actor.
bobbyTodo :: [Action Task]
bobbyTodo =
  [ (Task errand, Bobby)
  | errand <- [GoOutside, GetGroceries, GoHome, UnpackGroceries]
  ]

-- | Billy's actions followed by Bobby's, unsorted.
allTodo :: [Action Task]
allTodo = billyTodo ++ bobbyTodo

-- | A schedule obtained by plain sorting. Pairs compare on 'Task' first, which
-- uses the dependency rank, and then on 'Actor'.
scheduleLinear :: [Action Task]
scheduleLinear = sort todo
 where
  todo = allTodo

-- | A schedule built without relying on 'Task''s 'Ord' instance. The queue is
-- scanned repeatedly. The front action is emitted when no remaining action's
-- task is a (transitive) prerequisite of it; otherwise it is moved to the
-- back. Termination depends on the dependency graph being acyclic, and each
-- rotation costs O(n) for the append.
scheduleNaive :: [Action Task']
scheduleNaive = reverse (loop [] (map (first coerce) allTodo))
 where
  loop emitted [] = emitted
  loop emitted (next@(task, _) : rest)
    | any (dependsOn task . fst) rest = loop emitted (rest <> [next])
    | otherwise = loop (next : emitted) rest

-- | All actions from 'allTodo' as an ordered set. Equal pairs are collapsed,
-- and iteration follows the 'Ord' instance of @(Task, Actor)@.
tasks :: Set.Set (Action Task)
tasks = Set.fromList $ allTodo

-- | The initial action queue: every action in dependency order. It is exactly
-- 'tasks', so it is defined in terms of it rather than repeating the
-- construction and risking the two drifting apart.
actionQueue :: Set.Set (Action Task)
actionQueue = tasks

-- | Split off the least element of a set, or return 'Nothing' if the set is
-- empty. On an action queue, this is the action whose task ranks earliest.
nextAction :: Set.Set a -> Maybe (a, Set.Set a)
nextAction queue = Set.minView queue

-- | 'actionQueue' with one more action, Billy unpacking the groceries. The
-- set places it according to 'Ord', with no explicit re-sort.
newActionQueue :: Set.Set (Action Task)
newActionQueue = Set.insert extra actionQueue
 where
  extra = (Task UnpackGroceries, Billy)

-- | State for handing out tasks one at a time, in dependency order.
data NextTask = NextTask
  { dependencyCounts :: HashMap.HashMap Task Int
  -- ^ Keys are the tasks currently held. Each value counts how many of that
  -- task's direct prerequisites (per 'happensAfter') are also held.
  , roots :: HashSet.HashSet Task
  -- ^ Held tasks whose count is zero, i.e. ready to be popped.
  }
  deriving stock (Show, Eq)

-- | A 'NextTask' holding no tasks.
empty :: NextTask
empty = NextTask{dependencyCounts = mempty, roots = mempty}

-- | Start holding a task. Inserting a task that is already held is a no-op.
--
-- The new task's count is the number of its direct prerequisites (per
-- 'happensAfter') that are already held; the task becomes a root when that
-- number is zero. Every already-held direct dependent (per 'happensBefore')
-- gains one more pending prerequisite and so stops being a root.
--
-- Only direct edges are consulted. A task whose direct prerequisites are
-- absent becomes a root even if one of its transitive prerequisites is held.
insertTask :: Task -> NextTask -> NextTask
insertTask t s
  | t `HashMap.member` counts = s
  | otherwise =
      NextTask
        { dependencyCounts = HashMap.insert t pending bumped
        , roots =
            (if pending == 0 then HashSet.insert t else id)
              (roots s `HashSet.difference` heldDependents)
        }
 where
  counts = dependencyCounts s
  held = HashMap.keysSet counts
  -- A missing key means "no such edges"; an empty set takes care of that
  -- uniformly instead of separate cases for each lookup.
  prerequisites = HashMap.findWithDefault HashSet.empty t happensAfter
  dependents = HashMap.findWithDefault HashSet.empty t happensBefore
  heldDependents = dependents `HashSet.intersection` held
  pending = HashSet.size (prerequisites `HashSet.intersection` held)
  -- Each held dependent now also waits on t.
  bumped =
    HashMap.unionWith
      (+)
      (fmap (const 1) (HashSet.toMap heldDependents))
      counts

-- | Remove and return one root (a held task with no pending prerequisites),
-- or return 'Nothing' and the unchanged state if there is no root.
--
-- The popped task is no longer held. Each held direct dependent (per
-- 'happensBefore') has its count decremented, and those that reach zero
-- become roots. Which root is returned depends on the iteration order of the
-- underlying 'HashSet'.
popTask :: NextTask -> (Maybe Task, NextTask)
popTask s = case HashSet.toList (roots s) of
  [] -> (Nothing, s)
  x : _ ->
    let counts = dependencyCounts s
        dependents = HashMap.findWithDefault HashSet.empty x happensBefore
        -- Held dependents of x, each now waiting on one fewer prerequisite.
        decremented =
          HashMap.map (subtract 1) $
            HashMap.intersection counts (HashSet.toMap dependents)
        newRoots = HashMap.keysSet (HashMap.filter (== 0) decremented)
     in ( Just x
        , NextTask
            { -- Drop x on every path, including when it has no dependents.
              -- A popped task left in the counts would make 'insertTask'
              -- treat it as still held: it could not be re-inserted, and it
              -- would be counted as a held dependent of a later-inserted
              -- prerequisite, so it would be popped a second time.
              dependencyCounts =
                HashMap.delete x (decremented `HashMap.union` counts)
            , roots = HashSet.delete x (roots s) <> newRoots
            }
        )