2016-11-15 00:14:51 +01:00
|
|
|
-- Copyright 2016 TensorFlow authors.
|
|
|
|
--
|
|
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
-- you may not use this file except in compliance with the License.
|
|
|
|
-- You may obtain a copy of the License at
|
|
|
|
--
|
|
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
--
|
|
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
-- See the License for the specific language governing permissions and
|
|
|
|
-- limitations under the License.
|
|
|
|
|
|
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
|
|
|
|
|
|
-- | Testing tracing.
|
|
|
|
module Main where
|
|
|
|
|
|
|
|
import Control.Concurrent.MVar (newEmptyMVar, putMVar, tryReadMVar)
|
|
|
|
import Data.ByteString.Builder (toLazyByteString)
|
|
|
|
import Data.ByteString.Lazy (isPrefixOf)
|
|
|
|
import Data.Default (def)
|
|
|
|
import Lens.Family2 ((&), (.~))
|
|
|
|
import Test.Framework (defaultMain)
|
|
|
|
import Test.Framework.Providers.HUnit (testCase)
|
2016-11-18 19:42:02 +01:00
|
|
|
import Test.HUnit (assertBool, assertFailure)
|
2016-11-15 00:14:51 +01:00
|
|
|
|
|
|
|
import qualified TensorFlow.Core as TF
|
|
|
|
import qualified TensorFlow.Ops as TF
|
|
|
|
|
|
|
|
-- | Runs a trivial computation in a session configured with a custom tracer
-- and checks that the tracer was handed an entry starting with
-- @Session.extend@.
testTracing :: IO ()
testTracing = do
    -- The tracer deposits the message it is handed into this 'MVar'.
    traceSlot <- newEmptyMVar
    let options = def & TF.sessionTracer .~ putMVar traceSlot
    -- Tracing is expected to happen as a side-effect of graph extension.
    TF.runSessionWithOptions options (TF.run_ (TF.scalar (0 :: Float)))
    -- Non-blocking read: an empty slot means the tracer never ran.
    logged <- tryReadMVar traceSlot
    case logged of
        Nothing -> assertFailure "Logging never happened"
        Just entry -> checkEntry (toLazyByteString entry)
  where
    -- The rendered log entry must begin with the expected tag.
    checkEntry got =
        assertBool
            ("Unexpected log entry " ++ show got)
            ("Session.extend" `isPrefixOf` got)
|
|
|
|
|
2016-11-18 19:42:02 +01:00
|
|
|
-- | Entry point: hands the tracing test case to the test-framework runner.
main :: IO ()
main = defaultMain tests
  where
    tests = [testCase "Tracing" testTracing]
|