mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-27 03:59:46 +01:00
f170df9d13
In addition, you can now fetch TensorData directly. This might be useful in scenarios where you feed the result of a computation back in, as in an RNN.

Before:

benchmarking feedFetch/4 byte
time 83.31 μs (81.88 μs .. 84.75 μs)
0.997 R² (0.994 R² .. 0.998 R²)
mean 87.32 μs (86.06 μs .. 88.83 μs)
std dev 4.580 μs (3.698 μs .. 5.567 μs)
variance introduced by outliers: 55% (severely inflated)

benchmarking feedFetch/4 KiB
time 114.9 μs (111.5 μs .. 118.2 μs)
0.996 R² (0.994 R² .. 0.998 R²)
mean 117.3 μs (116.2 μs .. 118.6 μs)
std dev 3.877 μs (3.058 μs .. 5.565 μs)
variance introduced by outliers: 31% (moderately inflated)

benchmarking feedFetch/4 MiB
time 109.0 ms (107.9 ms .. 110.7 ms)
1.000 R² (0.999 R² .. 1.000 R²)
mean 108.6 ms (108.2 ms .. 109.2 ms)
std dev 740.2 μs (353.2 μs .. 1.186 ms)

After:

benchmarking feedFetch/4 byte
time 82.92 μs (80.55 μs .. 85.24 μs)
0.996 R² (0.993 R² .. 0.998 R²)
mean 83.58 μs (82.34 μs .. 84.89 μs)
std dev 4.327 μs (3.664 μs .. 5.375 μs)
variance introduced by outliers: 54% (severely inflated)

benchmarking feedFetch/4 KiB
time 85.69 μs (83.81 μs .. 87.30 μs)
0.997 R² (0.996 R² .. 0.999 R²)
mean 86.99 μs (86.11 μs .. 88.15 μs)
std dev 3.608 μs (2.854 μs .. 5.273 μs)
variance introduced by outliers: 43% (moderately inflated)

benchmarking feedFetch/4 MiB
time 1.582 ms (1.509 ms .. 1.677 ms)
0.970 R² (0.936 R² .. 0.993 R²)
mean 1.645 ms (1.554 ms .. 1.981 ms)
std dev 490.6 μs (138.9 μs .. 1.067 ms)
variance introduced by outliers: 97% (severely inflated)
43 lines
1.7 KiB
Haskell
43 lines
1.7 KiB
Haskell
-- Disable full-laziness to keep ghc from optimizing most of the benchmark away.
|
|
{-# OPTIONS_GHC -fno-full-laziness #-}
|
|
import Control.DeepSeq (NFData(rnf))
|
|
import Control.Exception (evaluate)
|
|
import Control.Monad.IO.Class (liftIO)
|
|
import Criterion.Main (defaultMain, bgroup, bench)
|
|
import Criterion.Types (Benchmarkable(..))
|
|
import qualified Data.Vector.Storable as S
|
|
import qualified TensorFlow.Core as TF
|
|
import qualified TensorFlow.Ops as TF
|
|
|
|
-- | Build a criterion 'Benchmarkable' that runs entirely inside one
-- 'TF.Session'.
--
-- The @setup@ action is executed exactly once. The function it yields is
-- then applied to @x@ once per iteration that criterion asks for. Each
-- result is fully forced with 'rnf' (via 'evaluate'), so the measured work
-- is actually performed rather than left as an unevaluated thunk.
nfSession :: NFData b => TF.Session (a -> TF.Session b) -> a -> Benchmarkable
nfSession setup x = Benchmarkable $ \iterations -> TF.runSession $ do
    step <- setup
    -- The iteration count is an Int64, so an Int-indexed combinator such as
    -- replicateM is not used; count down by hand instead.
    let loop remaining
            | remaining > 0 = do
                result <- step x
                liftIO (evaluate (rnf result))
                loop (remaining - 1)
            | otherwise = return ()
    loop iterations
|
|
|
|
-- | Benchmark feeding a vector into a graph and fetching it back out.
--
-- The outer action builds the graph: a placeholder declared with shape
-- @[-1]@ (presumably a 1-D tensor of unspecified length — confirm against
-- the TensorFlow shape conventions) and an 'TF.identity' of it. The function
-- it returns encodes its argument with a shape equal to the vector's length,
-- feeds it to the placeholder, and fetches the identity's output as a vector.
feedFetchBenchmark :: TF.Session (S.Vector Float -> TF.Session (S.Vector Float))
feedFetchBenchmark = do
    input <- TF.build (TF.placeholder (TF.Shape [-1]))
    output <- TF.build (TF.render (TF.identity input))
    pure $ \v ->
        let len = fromIntegral (S.length v)
            tensorData = TF.encodeTensorData (TF.Shape [len]) v
        in TF.runWithFeeds [TF.feed input tensorData] output
|
|
|
|
-- | Run the @feedFetch@ benchmark group on zero-filled Float vectors of 1,
-- 1024 and 1024*1024 elements (labelled 4 byte, 4 KiB and 4 MiB).
main :: IO ()
main = defaultMain [bgroup "feedFetch" (map benchAtSize sizes)]
  where
    -- (benchmark label, number of Float elements in the input vector)
    sizes :: [(String, Int)]
    sizes =
        [ ("4 byte", 1)
        , ("4 KiB", 1024)
        , ("4 MiB", 1024 * 1024)
        ]
    benchAtSize (label, count) =
        bench label $ nfSession feedFetchBenchmark (S.replicate count 0)
|