{-# LANGUAGE OverloadedLists #-}
module Main (main) where

import Data.Int (Int32)
import Data.Maybe (isJust)
import Control.Monad (when)
import Control.Monad.IO.Class (liftIO)
import qualified Data.Vector.Storable as V
import TensorFlow.Core
    ( unScalar
    , render
    , run_
    , runSession
    , run
    , withControlDependencies)
import qualified TensorFlow.Ops as Ops
import TensorFlow.Variable
    ( Variable
    , readValue
    , initializedValue
    , initializedVariable
    , assign
    , assignAdd
    , variable
    )
import Test.Framework (defaultMain, Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?), assertFailure)

main :: IO ()
main = defaultMain variableTests
  where
    -- | Every 'Variable' test in this module, in the order they are reported.
    variableTests :: [Test]
    variableTests =
        [ testInitializedVariable
        , testInitializedVariableShape
        , testInitializedValue
        , testDependency
        , testRereadRef
        , testAssignAdd
        ]

-- | An initialized variable can be read through a derived expression, and
-- running an @assign@ node changes what that same expression evaluates to.
testInitializedVariable :: Test
testInitializedVariable =
    testCase "testInitializedVariable" $ runSession $ do
        v <- initializedVariable 42
        overwrite <- assign v 24
        -- A pure graph expression over the variable; re-running it re-reads v.
        let plusOne = 1 + readValue v
        before <- run plusOne
        liftIO $ (43 :: Float) @=? unScalar before
        run_ overwrite  -- Updates v to a different value
        after <- run plusOne
        liftIO $ (25 :: Float) @=? unScalar after

-- | A variable initialized from a one-element vector constant reads back
-- that vector and reports the shape @[1]@.
testInitializedVariableShape :: Test
testInitializedVariableShape =
    testCase "testInitializedVariableShape" $ runSession $ do
        vec <- initializedVariable (Ops.constant [1] [42 :: Float])
        let current = readValue vec
        values <- run current
        liftIO $ [42] @=? (values :: V.Vector Float)
        dims <- run (Ops.shape current)
        liftIO $ [1] @=? (dims :: V.Vector Int32)

-- | 'initializedValue' yields @Just@ the initial tensor for a variable built
-- with 'initializedVariable', and @Nothing@ for one built with 'variable'.
testInitializedValue :: Test
testInitializedValue =
    testCase "testInitializedValue" $ runSession $ do
        withInit <- initializedVariable (Ops.constant [1] [42 :: Float])
        fetched <- run (initializedValue withInit)
        liftIO $ Just [42] @=? (fetched :: Maybe (V.Vector Float))

        noInit <- variable [1]
        -- Can't use @=? because there is no Show instance for Tensor.
        let hasInitValue = isJust (initializedValue (noInit :: Variable Float))
        when hasInitValue $
            liftIO $ assertFailure "initializedValue should be Nothing, got Just"

-- | Reading an uninitialized variable under a control dependency on an
-- @assign@ node observes the assigned value.
testDependency :: Test
testDependency =
    testCase "testDependency" $ runSession $ do
        v <- variable []
        setTo24 <- assign v 24
        -- Render the sum so that it carries the control dependency, then fetch it.
        fetched <- run =<< withControlDependencies setTo24
                                (render (readValue v + 18))
        liftIO $ (42 :: Float) @=? unScalar fetched

-- | Regression test for https://github.com/tensorflow/haskell/issues/92.
--
-- @f0@ is fetched before the assignment but is not forced until the final
-- assertion. Despite that deferred evaluation, it must still hold the value
-- the variable had when it was read (0), not the value assigned afterwards
-- (0.1).
testRereadRef :: Test
testRereadRef = testCase "testRereadRef" $ runSession $ do
    w <- initializedVariable 0
    -- Fetch the initial value before mutating the variable.
    f0 <- run (readValue w)
    run_ =<< assign w (Ops.scalar (0.1 :: Float))
    -- Re-read the same variable after the assignment has run.
    f1 <- run (readValue w)
    liftIO $ (0.0, 0.1) @=? (unScalar f0, unScalar f1)

-- | Running an 'assignAdd' node increments the variable in place.
testAssignAdd :: Test
testAssignAdd = testCase "testAssignAdd" $ runSession $ do
    w <- initializedVariable 42
    increment <- assignAdd w 17
    run_ increment
    total <- run (readValue w)
    let expected = 42 + 17 :: Float
    liftIO $ expected @=? unScalar total