-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE OverloadedStrings #-}

-- | Tests for the session tracing hook ('TF.sessionTracer').
module Main where

import Control.Concurrent.MVar (newEmptyMVar, putMVar, tryPutMVar, tryReadMVar)
import Control.Monad (void)
import Data.ByteString.Builder (toLazyByteString)
import Data.ByteString.Lazy (isPrefixOf)
import Data.Default (def)
import Lens.Family2 ((&), (.~))
import Test.Framework (defaultMain)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit (assertBool, assertFailure)

import qualified TensorFlow.Core as TF
import qualified TensorFlow.Ops as TF

-- | Verifies that tracing happens as a side-effect of graph extension.
--
-- Runs a one-op session whose tracer records the first log entry it
-- receives, then asserts that an entry was recorded and that it starts with
-- @"Session.extend"@.
testTracing :: IO ()
testTracing = do
    loggedValue <- newEmptyMVar
    -- Keep only the first trace. 'tryPutMVar' (instead of 'putMVar') makes
    -- any further trace calls no-ops rather than blocking forever on an
    -- already-full MVar.
    TF.runSessionWithOptions
        (def & TF.sessionTracer .~ (void . tryPutMVar loggedValue))
        (TF.run_ (TF.scalar (0 :: Float)))
    tryReadMVar loggedValue >>=
        maybe (assertFailure "Logging never happened") expectedFormat
  where
    -- The recorded entry must be the one emitted for "Session.extend".
    expectedFormat x =
        let got = toLazyByteString x
        in assertBool ("Unexpected log entry " ++ show got)
                      ("Session.extend" `isPrefixOf` got)

-- | Test-suite entry point: hands every test case in this module to
-- test-framework's 'defaultMain'.
main :: IO ()
main = defaultMain tests
  where
    tests = [testCase "Tracing" testTracing]