mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 11:29:43 +01:00
f170df9d13
In addition, you can now fetch TensorData directly. This might be useful in scenarios where you feed the result of a computation back in, as in an RNN.

Before:
benchmarking feedFetch/4 byte
time                 83.31 μs   (81.88 μs .. 84.75 μs)
                     0.997 R²   (0.994 R² .. 0.998 R²)
mean                 87.32 μs   (86.06 μs .. 88.83 μs)
std dev              4.580 μs   (3.698 μs .. 5.567 μs)
variance introduced by outliers: 55% (severely inflated)

benchmarking feedFetch/4 KiB
time                 114.9 μs   (111.5 μs .. 118.2 μs)
                     0.996 R²   (0.994 R² .. 0.998 R²)
mean                 117.3 μs   (116.2 μs .. 118.6 μs)
std dev              3.877 μs   (3.058 μs .. 5.565 μs)
variance introduced by outliers: 31% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 109.0 ms   (107.9 ms .. 110.7 ms)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 108.6 ms   (108.2 ms .. 109.2 ms)
std dev              740.2 μs   (353.2 μs .. 1.186 ms)

After:
benchmarking feedFetch/4 byte
time                 82.92 μs   (80.55 μs .. 85.24 μs)
                     0.996 R²   (0.993 R² .. 0.998 R²)
mean                 83.58 μs   (82.34 μs .. 84.89 μs)
std dev              4.327 μs   (3.664 μs .. 5.375 μs)
variance introduced by outliers: 54% (severely inflated)

benchmarking feedFetch/4 KiB
time                 85.69 μs   (83.81 μs .. 87.30 μs)
                     0.997 R²   (0.996 R² .. 0.999 R²)
mean                 86.99 μs   (86.11 μs .. 88.15 μs)
std dev              3.608 μs   (2.854 μs .. 5.273 μs)
variance introduced by outliers: 43% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 1.582 ms   (1.509 ms .. 1.677 ms)
                     0.970 R²   (0.936 R² .. 0.993 R²)
mean                 1.645 ms   (1.554 ms .. 1.981 ms)
std dev              490.6 μs   (138.9 μs .. 1.067 ms)
variance introduced by outliers: 97% (severely inflated)
83 lines
2.7 KiB
Haskell
83 lines
2.7 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
|
|
module Main where
|
|
|
|
import Control.Monad.IO.Class (liftIO)
|
|
import Data.Int (Int64)
|
|
import Google.Test (googleTest)
|
|
import TensorFlow.Types (Scalar(..))
|
|
import TensorFlow.Ops (scalar)
|
|
import TensorFlow.Queue
|
|
import TensorFlow.Session
|
|
( asyncProdNodes
|
|
, build
|
|
, buildAnd
|
|
, run
|
|
, runSession
|
|
, run_
|
|
)
|
|
import Test.Framework (Test)
|
|
import Test.Framework.Providers.HUnit (testCase)
|
|
import Test.HUnit ((@=?))
|
|
import qualified Data.ByteString as BS
|
|
|
|
-- | Test basic queue behaviors.
--
-- Builds a capacity-one queue of (Int64, ByteString) pairs. It pushes one
-- pair through, checks the dequeued pair matches, then repeats with a second
-- pair on the same (now drained) queue.
testBasic :: Test
testBasic = testCase "testBasic" . runSession $ do
    q :: Queue2 Int64 BS.ByteString <- build (makeQueue2 1 "")

    -- First round trip.
    buildAnd run_ $ enqueue q 42 (scalar "Hi")
    buildAnd run (dequeue q) >>= liftIO . ((Scalar 42, Scalar "Hi") @=?)

    -- Second round trip through the same queue.
    buildAnd run_ $ enqueue q 56 (scalar "Bar")
    buildAnd run (dequeue q) >>= liftIO . ((Scalar 56, Scalar "Bar") @=?)
|
|
|
|
-- | Test queue pumping.
--
-- The enqueue ("pump") and dequeue nodes are built once, up front. The pump
-- is then run twice to fill the capacity-two queue, and the dequeue node is
-- run twice to drain it.
testPump :: Test
testPump = testCase "testPump" $ runSession $ do
    (deq, pump) <- build $ do
        q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"
        (,) <$> dequeue q
            <*> enqueue q 31 (scalar "Baz")
    -- This is a realistic use. The pump inputs are pre-bound to some
    -- nodes that produce values when pumped (e.g. read from a
    -- file).
    --
    -- Each pump gets its own 'run_'. NOTE(review): naming the same node
    -- twice within a single run (e.g. @run_ (pump, pump)@) appears to be
    -- deduplicated by node name, so it would likely enqueue only once.
    -- Separate runs guarantee two enqueues. The queue holds 2 elements,
    -- so neither run blocks.
    run_ pump
    run_ pump

    -- Likewise, dequeue in two separate runs so two distinct elements are
    -- actually consumed, rather than one tensor fetched twice.
    x <- run deq
    y <- run deq
    liftIO $ (Scalar 31, Scalar "Baz") @=? x
    liftIO $ (Scalar 31, Scalar "Baz") @=? y
|
|
|
|
-- | Test dequeuing from a queue that is filled by a background pump.
--
-- 'asyncProdNodes' keeps running the enqueue node, and the foreground
-- dequeues twice and checks each element.
testAsync :: Test
testAsync = testCase "testAsync" $ runSession $ do
    (deq, pump) <- build $ do
        q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""
        (,) <$> dequeue q
            <*> enqueue q 10 (scalar "Async")
    -- Pumps the queue until canceled by runSession exiting.
    asyncProdNodes pump
    -- Picks up a couple values and verifies they are as expected.
    let expectAsync = do
            got <- run deq
            liftIO $ (Scalar 10, Scalar "Async") @=? got
    expectAsync
    expectAsync
|
|
|
|
-- | Entry point: hands every test case in this module to 'googleTest'.
main :: IO ()
main = googleTest allTests
  where
    allTests :: [Test]
    allTests =
        [ testBasic
        , testPump
        , testAsync
        ]
|