1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 19:13:34 +02:00

More fixes.

This commit is contained in:
Judah Jacobson 2017-03-03 10:14:01 -08:00
parent 63e0fae505
commit 27d7c0b069
3 changed files with 18 additions and 8 deletions

View File

@ -20,9 +20,8 @@ module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Data.Functor.Identity (Identity(..))
import Google.Test (googleTest)
import TensorFlow.Types (ListOf(..), Scalar(..))
import TensorFlow.Types (ListOf(..), Scalar(..), (|:|))
import TensorFlow.Ops (scalar)
import TensorFlow.Queue
import TensorFlow.Session
@ -44,11 +43,16 @@ testBasic = testCase "testBasic" $ runSession $ do
q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
buildAnd run_ $ enqueue q $ 42 :| scalar "Hi" :| Nil
x <- buildAnd run (dequeue q)
liftIO $ (Identity (Scalar 42) :| Identity (Scalar "Hi") :| Nil) @=? x
liftIO $ (Scalar 42 |:| Scalar "Hi" |:| Nil) @=? x
buildAnd run_ $ enqueue q $ 56 :| scalar "Bar" :| Nil
y <- buildAnd run (dequeue q)
let expected = Identity (Scalar 56) :| Identity (Scalar "Bar") :| Nil
-- Note: we use explicit "Scalar" here to specify the type that was
-- fetched. Equivalently we could write
-- 56 |:| "Bar" |:| Nil :: List [Scalar Int64, Scalar BS.ByteString]
-- or else allow the types to be determined by future use of the fetched
-- value.
let expected = Scalar 56 |:| Scalar "Bar" |:| Nil
liftIO $ expected @=? y
-- | Test queue pumping.
@ -64,7 +68,7 @@ testPump = testCase "testPump" $ runSession $ do
run_ (pump, pump)
(x, y) <- run (deq, deq)
let expected = Identity (Scalar 31) :| Identity (Scalar "Baz") :| Nil
let expected = Scalar 31 |:| Scalar "Baz" |:| Nil
liftIO $ expected @=? x
liftIO $ expected @=? y
@ -77,7 +81,7 @@ testAsync = testCase "testAsync" $ runSession $ do
-- Pumps the queue until canceled by runSession exiting.
asyncProdNodes pump
-- Picks up a couple values and verifies they are as expected.
let expected = Identity (Scalar 10) :| Identity (Scalar "Async") :| Nil
let expected = Scalar 10 |:| Scalar "Async" |:| Nil
run deq >>= liftIO . (expected @=?)
run deq >>= liftIO . (expected @=?)

View File

@ -30,7 +30,6 @@ import Control.Monad (replicateM)
import Control.Monad.Reader (ReaderT, runReaderT, ask)
import Control.Monad.State.Strict (State, runState, get, put)
import Data.Int (Int64)
import Data.Proxy (Proxy(..))
import Lens.Family2 ((&), (<>~), (^.))
import TensorFlow.Build

View File

@ -42,6 +42,7 @@ module TensorFlow.Types
-- * Lists
, ListOf(..)
, List
, (|:|)
, TensorTypeProxy(..)
, TensorTypes(..)
, TensorTypeList
@ -60,7 +61,7 @@ module TensorFlow.Types
, AllTensorTypes
) where
import Data.Functor.Identity (Identity)
import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int8, Int16, Int32, Int64)
@ -421,6 +422,12 @@ instance All Show (Map f as) => Show (ListOf f as) where
type List = ListOf Identity
-- | Equivalent of ':|' for lists.
(|:|) :: a -> List as -> List (a ': as)
(|:|) = (:|) . Identity
infixr 5 |:|
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
--
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the