{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE RankNTypes #-}

module Main where

import Data.Iteratee as I
import Criterion.Main
import Control.Monad.Identity
import Control.Monad.Trans


-- | Run an iteratee behind an enumeratee (flattened with 'joinI'), feeding
-- the input list in chunks of 5 elements, and return the final result.
runner
  :: Enumeratee [Int] xs Identity a
  -> [Int]
  -> Iteratee xs Identity a
  -> a
runner etee xs iter = runIdentity (I.run =<< feed (joinI (etee iter)))
  where
    feed = enumPureNChunk xs 5

-- | Like 'runner', but attaches the enumeratee to the iteratee with '=$'
-- instead of @joinI . etee@; exercises fusion of enumeratee/iteratee
-- composition.
runner2
  :: Enumeratee [Int] xs Identity a
  -> [Int]
  -> Iteratee xs Identity a
  -> a
runner2 etee xs iter = runIdentity $ do
  fed <- enumPureNChunk xs 5 (etee =$ iter)
  I.run fed

-- | Like 'runner', but pre-composes the chunking enumerator with the
-- enumeratee via '$='; exercises fusion of enumerator/enumeratee
-- composition.
runner3
  :: Enumeratee [Int] xs Identity a
  -> [Int]
  -> Iteratee xs Identity a
  -> a
runner3 etee xs iter =
  runIdentity (I.run =<< (enumPureNChunk xs 5 $= etee) iter)

-- | Two chained 'mapChunks' stages: an identity map, then add one to every
-- element.
m2 :: Enumeratee [Int] [Int] Identity a
m2 = (><>) (mapChunks id) (mapChunks (map (+ 1)))

-- | Identity map, increment, then keep only the even elements.
m3 :: Enumeratee [Int] [Int] Identity a
m3 = mapChunks id ><> mapChunks (map (+ 1)) ><> I.filter even

-- | Two copies of 'm2' chained together (a nested mapChunks pipeline).
m4 :: Enumeratee [Int] [Int] Identity a
m4 = (><>) m2 m2

-- | Deeply nested pipeline: 'm3', 'm2', 'm3', 'm2' chained with '><>';
-- measured by the "highly nested fusion" benchmark.
m10 :: Enumeratee [Int] [Int] Identity a
m10 = m3 ><> m2 ><> m3 ><> m2

-- | Run an iteratee over 1..100 passed through 'm2'.
fusedMap :: Iteratee [Int] Identity a -> a
fusedMap = runner m2 (enumFromTo 1 100)

-- | Run an iteratee over 1..100 passed through 'm3' (map + filter).
fusedMap3 :: Iteratee [Int] Identity a -> a
fusedMap3 = runner m3 (enumFromTo 1 100)

-- | Run an iteratee over 1..100 passed through 'm4'.
fusedMap4 :: Iteratee [Int] Identity a -> a
fusedMap4 = runner m4 (enumFromTo 1 100)

-- | Run an iteratee over 1..100 passed through the deeply nested 'm10'.
fusedMap10 :: Iteratee [Int] Identity a -> a
fusedMap10 = runner m10 (enumFromTo 1 100)

-- Experiment: stream-fusion-like constructs for building enumeratees.

-- | A stream transformer from input elements of type @b@ to output elements
-- of type @a@: a step function over some hidden state @s@, together with
-- the initial state (a strict field).
data StreamF m b a = forall s. StreamF (s -> m (Step s b a)) !s

-- | Result of running one step of a 'StreamF'.
data Step s b a =
    Done                  -- ^ no further output; 'cmp_t' propagates it as Done
  | Yield [a] !s          -- ^ emit outputs without consuming any input
  | Next (b -> P2 [a] s)  -- ^ consume one input element, returning outputs
                          --   and the next state

-- | Strict pair: both components are evaluated to WHNF along with the pair.
data P2 a b = P2 !a !b

-- | Stateless transformer (state @()@) that always asks for input and
-- emits exactly @[fn x]@ for each input element @x@.
map_t :: Monad m => (a -> b) -> StreamF m a b
map_t fn = StreamF step ()
  where
    step () = return $ Next emit
    emit x = P2 [fn x] ()
{-# INLINE map_t #-}

-- | Stateless transformer (state @()@) that always asks for input and
-- re-emits an element only when the predicate holds for it.
filter_t :: Monad m => (a -> Bool) -> StreamF m a a
filter_t keep = StreamF step ()
  where
    step () = return (Next (\x -> P2 [x | keep x] ()))
{-# INLINE filter_t #-}

-- | Sequential composition of two stream transformers: outputs of the first
-- are fed as inputs to the second.
--
-- The composite state is @(s1, s2, supply)@, where @supply@ holds outputs of
-- the first transformer that the second has not consumed yet. The second
-- transformer is stepped first; the first one is only stepped when the
-- second wants input and @supply@ is empty.
--
-- NOTE(review): when the first transformer answers 'Next', the resulting
-- state keeps the old @s2@, so @fn2 s2@ is evaluated again on the following
-- step. That is harmless for the effect-free steps in this file, but effects
-- in @m@ would be repeated. Any @supply@ still buffered when either side
-- returns 'Done' is dropped.
cmp_t :: Monad m => StreamF m a b -> StreamF m b c -> StreamF m a c
cmp_t (StreamF fn1 s1_0) (StreamF fn2 s2_0) = StreamF loop (s1_0,s2_0,[])
  where
    loop (s1,s2,supply) = fn2 s2 >>= \r2 -> case r2 of
        Done -> return Done
        -- second transformer emits without needing input: pass it through
        Yield cS s2' -> return $ Yield cS (s1,s2',supply)
        Next fn -> case supply of
            -- feed one buffered element to the second transformer
            (b:bS) -> let P2 cS s2' = fn b
                      in  return $ Yield cS (s1,s2',bS)
            -- nothing buffered: step the first transformer
            [] -> fn1 s1 >>= \r1 -> case r1 of
                Done -> return Done
                -- supply is empty here, so these outputs become the supply
                Yield aS s1' -> loop (s1', s2, aS)
                -- the first transformer needs outside input, so the
                -- composite does too; its outputs are buffered, none emitted
                Next f -> return $ Next $ \a ->
                    let P2 bS s1'  = f a
                    in  P2 [] (s1',s2,bS)
{-# INLINE cmp_t #-}

{-
id_t :: Monad m => StreamF m a -> StreamF m a
id_t (StreamF istep s0) = StreamF loop s0
  where
    loop s = istep s >>= \r -> case r of
        Done       -> return Done
        Yield a s' -> return $ Yield a s'
        Skip  s'   -> return $ Skip s'
{-# INLINE id_t #-}

map_t :: Monad m => (a -> b) -> StreamF m a -> StreamF m b
map_t fn (StreamF istep s0) = StreamF loop s0
  where
    loop s = istep s >>= \r -> case r of
        Done       -> return   Done
        Yield a s' -> return $ Yield (fn a) s'
        Skip  s'   -> return $ Skip s'
{-# INLINE map_t #-}

filter_t :: Monad m => (a -> Bool) -> StreamF m a -> StreamF m a
filter_t pred (StreamF istep s0) = StreamF loop s0
  where
    loop s = istep s >>= \r -> case r of
        Done       -> return Done
        Yield a s' -> return $ if pred a then Yield a s' else Skip s'
        Skip s'    -> return $ Skip s'
{-# INLINE filter_t #-}

iStream :: Monad m => StreamF (Iteratee [a] m) a
iStream = StreamF loop []
  where
    loop (x:xs) = return $ Yield x xs
    loop []     = do
        r <- isStreamFinished
        case r of
          Nothing -> getChunk >>= loop
          Just _  -> return Done
{-# INLINE iStream #-}

-- this isn't really the correct type signature, but I don't know how to write
-- what it actually is.  Maybe it won't be a problem, with the correct type
-- class constraints.
etee_t :: Monad m => (forall m. Monad m => StreamF m a -> StreamF m b) -> Enumeratee [a] [b] m x
etee_t stream_fn = case stream_fn iStream of
    StreamF b_fn s0 -> unfoldConvStream fn s0
      where
        fn s = do
          stepRes <- b_fn s
          case stepRes of
            Done       -> return (s,[])
            Yield a s' -> return (s',[a])
            Skip  s'   -> return (s',[])
{-# INLINE etee_t #-}

type Trans m a b = Monad m => StreamF m a -> StreamF m b
-}


-- | Turn a 'StreamF' transformer into an enumeratee over list chunks.
--
-- Each step fetches one chunk from the outer stream and drives the
-- transformer over the whole chunk. Driving continues until the transformer
-- asks for input ('Next') and no input is left, or until it reports 'Done'.
-- All outputs produced along the way are sent downstream as one chunk.
--
-- The whole chunk is consumed within a single step because
-- 'unfoldConvStream' stops as soon as the outer stream reaches EOF. Input
-- buffered in the accumulator across steps (the previous design kept the
-- rest of the chunk there and advanced one element per step) would
-- therefore be lost at the end of the stream.
--
-- After 'Done', later chunks are still consumed but produce no output,
-- presumably because the step function keeps returning 'Done'. NOTE(review):
-- the transformer cannot stop the outer stream early. A transformer that
-- only ever 'Yield's without asking for input would not terminate here.
etee_t :: Monad m => StreamF m a b -> Enumeratee [a] [b] m x
etee_t (StreamF step s0) = unfoldConvStream feedChunk s0
  where
    -- unfoldConvStream only calls this when the stream is not finished,
    -- so 'getChunk' has input to return.
    feedChunk s = getChunk >>= lift . drive s []

    -- Run the transformer over @supply@. Outputs are accumulated in
    -- reverse, one list per step, and flattened once at the end.
    drive s acc supply = step s >>= \r -> case r of
      Done          -> return (s, concat (reverse acc))
      Yield outs s' -> drive s' (outs : acc) supply
      Next f        -> case supply of
        []     -> return (s, concat (reverse acc))
        (x:xs) -> let P2 outs s' = f x
                  in  drive s' (outs : acc) xs

-- | A 'StreamF' transformer constrained only by @Monad m@; signatures such
-- as @m2_t :: Trans m Int Int@ leave @m@ free so it is chosen at use site.
type Trans m a b = Monad m => StreamF m a b

-- | Stream analogue of 'm2': identity map, then increment.
m2_t :: Trans m Int Int
m2_t = cmp_t (map_t id) (map_t (+ 1))

-- | Stream analogue of 'm3': identity map, increment, keep even elements.
m3_t :: Trans m Int Int
m3_t = cmp_t (cmp_t (map_t id) (map_t (+ 1))) (filter_t even)

-- | Stream analogue of 'm4': two copies of 'm2_t' composed.
m4_t :: Trans m Int Int
m4_t = cmp_t m2_t m2_t

-- | Run an iteratee over 1..100 passed through the stream pipeline 'm2_t'.
fusedMap_t :: Iteratee [Int] Identity a -> a
fusedMap_t = runner_t m2_t (enumFromTo 1 100)

-- | Run an iteratee over 1..100 passed through the stream pipeline 'm3_t'.
fusedMap3_t :: Iteratee [Int] Identity a -> a
fusedMap3_t = runner_t m3_t (enumFromTo 1 100)

-- | Run an iteratee over 1..100 passed through the stream pipeline 'm4_t'.
fusedMap4_t :: Iteratee [Int] Identity a -> a
fusedMap4_t = runner_t m4_t (enumFromTo 1 100)

-- | Stream-based counterpart of 'runner': converts the transformer to an
-- enumeratee with 'etee_t', feeds the input in chunks of 5 elements and
-- returns the iteratee's final result.
runner_t
  :: (forall m. Trans m Int x)
  -> [Int]
  -> Iteratee [x] Identity a
  -> a
runner_t trans xs iter = runIdentity $ do
  fed <- enumPureNChunk xs 5 (joinI (etee_t trans iter))
  I.run fed

-- | Criterion benchmarks: each pipeline sums its output, evaluated to WHNF.
-- Labels starting with "stream" use the 'StreamF'-based pipelines.
fusionBenches :: [Benchmark]
fusionBenches = [ bench label (whnf run I.sum) | (label, run) <- cases ]
  where
    cases :: [(String, Iteratee [Int] Identity Int -> Int)]
    cases =
      [ ("mapChunks/mapChunks fusion", fusedMap)
      , ("mapChunks/filter fusion", fusedMap3)
      , ("nested mapChunks/mapChunks fusion", fusedMap4)
      , ("highly nested fusion", fusedMap10)
      , ("stream mapChunks/mapChunks", fusedMap_t)
      , ("stream mapChunks/filter", fusedMap3_t)
      , ("stream nested mapChunks/mapChunks", fusedMap4_t)
      ]

-- | Print each pipeline's result once, so the mapChunks-based and
-- stream-based versions can be compared by eye. Then run the Criterion
-- benchmarks.
main :: IO ()
main = do
    putStrLn "\n\n Original"
    print "fusedMap"
    print $ fusedMap I.sum
    print "fusedMap/filter"
    print $ fusedMap3 I.sum
    print "fusedMap4"
    print $ fusedMap4 I.sum
    print "fusedMap10"
    print $ fusedMap10 I.sum

    -- This section must call the *_t variants so it actually exercises the
    -- StreamF/etee_t pipelines rather than re-running the ones above.
    putStrLn "\n\n Stream-based"
    print "fusedMap"
    print $ fusedMap_t I.sum
    print "fusedMap/filter"
    print $ fusedMap3_t I.sum
    print "fusedMap4"
    print $ fusedMap4_t I.sum

    defaultMain fusionBenches
