mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 19:13:34 +02:00
More fixes.
This commit is contained in:
parent
63e0fae505
commit
27d7c0b069
|
@ -20,9 +20,8 @@ module Main where
|
|||
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Int (Int64)
|
||||
import Data.Functor.Identity (Identity(..))
|
||||
import Google.Test (googleTest)
|
||||
import TensorFlow.Types (ListOf(..), Scalar(..))
|
||||
import TensorFlow.Types (ListOf(..), Scalar(..), (|:|))
|
||||
import TensorFlow.Ops (scalar)
|
||||
import TensorFlow.Queue
|
||||
import TensorFlow.Session
|
||||
|
@ -44,11 +43,16 @@ testBasic = testCase "testBasic" $ runSession $ do
|
|||
q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
|
||||
buildAnd run_ $ enqueue q $ 42 :| scalar "Hi" :| Nil
|
||||
x <- buildAnd run (dequeue q)
|
||||
liftIO $ (Identity (Scalar 42) :| Identity (Scalar "Hi") :| Nil) @=? x
|
||||
liftIO $ (Scalar 42 |:| Scalar "Hi" |:| Nil) @=? x
|
||||
|
||||
buildAnd run_ $ enqueue q $ 56 :| scalar "Bar" :| Nil
|
||||
y <- buildAnd run (dequeue q)
|
||||
let expected = Identity (Scalar 56) :| Identity (Scalar "Bar") :| Nil
|
||||
-- Note: we use explicit "Scalar" here to specify the type that was
|
||||
-- fetched. Equivalently we could write
|
||||
-- 56 |:| "Bar" |:| Nil :: List [Scalar Int64, Scalar BS.ByteString]
|
||||
-- or else allow the types to be determined by future use of the fetched
|
||||
-- value.
|
||||
let expected = Scalar 56 |:| Scalar "Bar" |:| Nil
|
||||
liftIO $ expected @=? y
|
||||
|
||||
-- | Test queue pumping.
|
||||
|
@ -64,7 +68,7 @@ testPump = testCase "testPump" $ runSession $ do
|
|||
run_ (pump, pump)
|
||||
|
||||
(x, y) <- run (deq, deq)
|
||||
let expected = Identity (Scalar 31) :| Identity (Scalar "Baz") :| Nil
|
||||
let expected = Scalar 31 |:| Scalar "Baz" |:| Nil
|
||||
liftIO $ expected @=? x
|
||||
liftIO $ expected @=? y
|
||||
|
||||
|
@ -77,7 +81,7 @@ testAsync = testCase "testAsync" $ runSession $ do
|
|||
-- Pumps the queue until canceled by runSession exiting.
|
||||
asyncProdNodes pump
|
||||
-- Picks up a couple values and verifies they are as expected.
|
||||
let expected = Identity (Scalar 10) :| Identity (Scalar "Async") :| Nil
|
||||
let expected = Scalar 10 |:| Scalar "Async" |:| Nil
|
||||
run deq >>= liftIO . (expected @=?)
|
||||
run deq >>= liftIO . (expected @=?)
|
||||
|
||||
|
|
|
@ -30,7 +30,6 @@ import Control.Monad (replicateM)
|
|||
import Control.Monad.Reader (ReaderT, runReaderT, ask)
|
||||
import Control.Monad.State.Strict (State, runState, get, put)
|
||||
import Data.Int (Int64)
|
||||
import Data.Proxy (Proxy(..))
|
||||
import Lens.Family2 ((&), (<>~), (^.))
|
||||
|
||||
import TensorFlow.Build
|
||||
|
|
|
@ -42,6 +42,7 @@ module TensorFlow.Types
|
|||
-- * Lists
|
||||
, ListOf(..)
|
||||
, List
|
||||
, (|:|)
|
||||
, TensorTypeProxy(..)
|
||||
, TensorTypes(..)
|
||||
, TensorTypeList
|
||||
|
@ -60,7 +61,7 @@ module TensorFlow.Types
|
|||
, AllTensorTypes
|
||||
) where
|
||||
|
||||
import Data.Functor.Identity (Identity)
|
||||
import Data.Functor.Identity (Identity(..))
|
||||
import Data.Complex (Complex)
|
||||
import Data.Default (def)
|
||||
import Data.Int (Int8, Int16, Int32, Int64)
|
||||
|
@ -421,6 +422,12 @@ instance All Show (Map f as) => Show (ListOf f as) where
|
|||
|
||||
type List = ListOf Identity
|
||||
|
||||
-- | Equivalent of ':|' for lists.
|
||||
(|:|) :: a -> List as -> List (a ': as)
|
||||
(|:|) = (:|) . Identity
|
||||
|
||||
infixr 5 |:|
|
||||
|
||||
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
|
||||
--
|
||||
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
|
||||
|
|
Loading…
Reference in New Issue
Block a user