1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-14 23:19:43 +01:00
tensorflow-haskell/tensorflow-ops/tests/TracingTest.hs
Judah Jacobson 2c5c879037 Introduce a MonadBuild class, and remove buildAnd. (#83)
This change adds a class that both `Build` and `Session` are instances of:

    class MonadBuild m where
        build :: Build a -> m a

All stateful ops (generated and manually written) now return their result in any
monad `m` that is an instance of `MonadBuild` (rather than only in `Build`).  For example:

    assign_ :: (MonadBuild m, TensorType t)
            => Tensor Ref t -> Tensor v t -> m (Tensor Ref t)

This lets us remove a bunch of spurious calls to `build` in user code.  It also
lets us replace the pattern `buildAnd run foo` with the simpler pattern `foo >>= run`
(or `run =<< foo`, which is sometimes nicer when foo is a complicated expression).

I went ahead and deleted `buildAnd` altogether since it seems to lead to
confusion; in particular, a few tests had `buildAnd run . pure`, which is
actually equivalent to just `run`.
2017-03-18 12:08:53 -07:00

49 lines
1.7 KiB
Haskell

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE OverloadedStrings #-}
-- | Testing tracing.
module Main where
import Control.Concurrent.MVar (newEmptyMVar, putMVar, tryReadMVar)
import Data.ByteString.Builder (toLazyByteString)
import Data.ByteString.Lazy (isPrefixOf)
import Data.Default (def)
import Lens.Family2 ((&), (.~))
import Test.Framework (defaultMain)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit (assertBool, assertFailure)
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Ops as TF
-- | Checks that tracing happens as a side-effect of graph extension: a
-- session configured with a 'TF.sessionTracer' must hand it an entry whose
-- rendered bytes begin with @"Session.extend"@.
testTracing :: IO ()
testTracing = do
    traceBox <- newEmptyMVar
    -- The tracer stores whatever entry it is given in 'traceBox'.
    -- NOTE(review): 'putMVar' blocks when the MVar is already full, so this
    -- assumes the tracer fires at most once during the session; confirm if
    -- more tracing points are added to the library.
    let options = def & TF.sessionTracer .~ putMVar traceBox
    TF.runSessionWithOptions options $ TF.run_ (TF.scalar (0 :: Float))
    -- Non-blocking read: an empty MVar means the tracer was never invoked.
    recorded <- tryReadMVar traceBox
    case recorded of
        Nothing    -> assertFailure "Logging never happened"
        Just entry -> checkEntry entry
  where
    -- Renders the builder and asserts it starts with the extend marker.
    checkEntry entry = do
        let rendered = toLazyByteString entry
        assertBool ("Unexpected log entry " ++ show rendered)
                   ("Session.extend" `isPrefixOf` rendered)
-- | Entry point: runs every test in this module through test-framework's
-- default driver.
main :: IO ()
main = defaultMain tests
  where
    tests = [testCase "Tracing" testTracing]