mirror of
https://github.com/tensorflow/haskell.git
synced 2025-02-17 05:25:05 +01:00
* Add MNIST data to gitignore * Add simple tensor round-trip benchmark * Use deepseq + cleaner imports * Use safe version of fromIntegral in FFI code * Don't copy data when fetching tensors BEFORE benchmarking feedFetch/4 byte time 55.79 μs (54.88 μs .. 56.62 μs) 0.998 R² (0.997 R² .. 0.999 R²) mean 55.61 μs (55.09 μs .. 56.11 μs) std dev 1.828 μs (1.424 μs .. 2.518 μs) variance introduced by outliers: 34% (moderately inflated) benchmarking feedFetch/4 KiB time 231.4 μs (221.9 μs .. 247.3 μs) 0.988 R² (0.974 R² .. 1.000 R²) mean 226.6 μs (224.1 μs .. 236.2 μs) std dev 13.45 μs (7.115 μs .. 27.14 μs) variance introduced by outliers: 57% (severely inflated) benchmarking feedFetch/4 MiB time 485.8 ms (424.6 ms .. 526.7 ms) 0.998 R² (0.994 R² .. 1.000 R²) mean 515.7 ms (512.5 ms .. 517.9 ms) std dev 3.320 ms (0.0 s .. 3.822 ms) variance introduced by outliers: 19% (moderately inflated) AFTER benchmarking feedFetch/4 byte time 53.11 μs (52.12 μs .. 54.22 μs) 0.996 R² (0.995 R² .. 0.998 R²) mean 54.64 μs (53.59 μs .. 56.18 μs) std dev 4.249 μs (2.910 μs .. 6.076 μs) variance introduced by outliers: 75% (severely inflated) benchmarking feedFetch/4 KiB time 83.83 μs (82.72 μs .. 84.92 μs) 0.999 R² (0.998 R² .. 0.999 R²) mean 83.82 μs (83.20 μs .. 84.35 μs) std dev 1.943 μs (1.557 μs .. 2.614 μs) variance introduced by outliers: 20% (moderately inflated) benchmarking feedFetch/4 MiB time 95.54 ms (93.62 ms .. 97.82 ms) 0.999 R² (0.998 R² .. 1.000 R²) mean 96.61 ms (95.76 ms .. 97.51 ms) std dev 1.408 ms (1.005 ms .. 1.889 ms)
43 lines
1.7 KiB
Haskell
43 lines
1.7 KiB
Haskell
-- Disable full-laziness to keep GHC from optimizing most of the benchmark away.
|
|
{-# OPTIONS_GHC -fno-full-laziness #-}
|
|
import Control.DeepSeq (NFData(rnf))
|
|
import Control.Exception (evaluate)
|
|
import Control.Monad.IO.Class (liftIO)
|
|
import Criterion.Main (defaultMain, bgroup, bench)
|
|
import Criterion.Types (Benchmarkable(..))
|
|
import qualified Data.Vector as V
|
|
import qualified TensorFlow.Core as TF
|
|
import qualified TensorFlow.Ops as TF
|
|
|
|
-- | Build a 'Benchmarkable' out of a 'TF.Session' action.
--
-- All iterations share one tensorflow session: the outer 'TF.Session'
-- action is executed a single time to obtain a function, and that function
-- is then applied to @x@ once per requested iteration. Every result is
-- forced with 'rnf' before the next iteration starts.
nfSession :: NFData b => TF.Session (a -> TF.Session b) -> a -> Benchmarkable
nfSession init x = Benchmarkable runIterations
  where
    -- Runs the whole batch of iterations inside a single session.
    runIterations iterations = TF.runSession $ do
        step <- init
        -- Hand-rolled countdown; per the original note, replicateM is not
        -- usable because the iteration count is an Int64.
        let loop remaining
                | remaining <= 0 = return ()
                | otherwise = do
                    result <- step x
                    liftIO (evaluate (rnf result))
                    loop (remaining - 1)
        loop iterations
|
|
|
|
-- | Round-trip a vector through the graph: feed it in, fetch it back out.
--
-- The setup step builds a placeholder with shape @[-1]@ and renders an
-- identity op over it. The returned function encodes its argument with a
-- one-dimensional shape equal to the vector's length, feeds that to the
-- placeholder, and fetches the identity output.
feedFetchBenchmark :: TF.Session (V.Vector Float -> TF.Session (V.Vector Float))
feedFetchBenchmark = do
    source <- TF.build $ TF.placeholder (TF.Shape [-1])
    sink <- TF.build . TF.render $ TF.identity source
    pure $ \vec ->
        let dims = TF.Shape [fromIntegral (V.length vec)]
            payload = TF.encodeTensorData dims vec
         in TF.runWithFeeds [TF.feed source payload] sink
|
|
|
|
-- | Entry point: runs the "feedFetch" benchmark group, one benchmark per
-- vector size (1, 1024 and 1024 * 1024 'Float' elements, all zeros).
main :: IO ()
main = defaultMain [bgroup "feedFetch" (map sizedCase cases)]
  where
    -- Benchmark label paired with the number of elements to round-trip.
    cases =
        [ ("4 byte", 1)
        , ("4 KiB", 1024)
        , ("4 MiB", 1024 * 1024)
        ]
    -- One feed/fetch benchmark over a zero-filled vector of the given length.
    sizedCase (label, len) =
        bench label (nfSession feedFetchBenchmark (V.replicate len 0))
|