mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 11:03:34 +02:00
Use deepseq + cleaner imports
This commit is contained in:
parent
3661f329ef
commit
5616b1b008
|
@ -192,7 +192,13 @@ Benchmark FeedFetchBench
|
|||
type: exitcode-stdio-1.0
|
||||
main-is: FeedFetchBench.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: base, criterion, tensorflow, tensorflow-ops, vector, transformers
|
||||
build-depends: base
|
||||
, criterion
|
||||
, deepseq
|
||||
, tensorflow
|
||||
, tensorflow-ops
|
||||
, transformers
|
||||
, vector
|
||||
ghc-options: -O2 -threaded
|
||||
|
||||
source-repository head
|
||||
|
|
|
@ -1,31 +1,25 @@
|
|||
-- Disable full-laziness to keep ghc from optimizing most of the benchmark away.
|
||||
{-# OPTIONS_GHC -fno-full-laziness #-}
|
||||
import Control.DeepSeq (NFData(rnf))
|
||||
import Control.Exception (evaluate)
|
||||
import Control.Monad (replicateM_)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Criterion.Main (defaultMain, bgroup, bench, nfIO)
|
||||
import Criterion.Main (defaultMain, bgroup, bench)
|
||||
import Criterion.Types (Benchmarkable(..))
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.Build as TF
|
||||
import qualified TensorFlow.ControlFlow as TF
|
||||
import qualified TensorFlow.Gradient as TF
|
||||
import qualified TensorFlow.Nodes as TF
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
import qualified TensorFlow.Types as TF
|
||||
|
||||
-- | Create 'Benchmarkable' for 'TF.Session'.
|
||||
--
|
||||
-- The entire benchmark will be run in a single tensorflow session. The
|
||||
-- 'TF.Session' argument will be run once and then its result will be run N
|
||||
-- times.
|
||||
whnfSession :: TF.Session (a -> TF.Session b) -> a -> Benchmarkable
|
||||
whnfSession init x = Benchmarkable $ \m -> TF.runSession $ do
|
||||
nfSession :: NFData b => TF.Session (a -> TF.Session b) -> a -> Benchmarkable
|
||||
nfSession init x = Benchmarkable $ \m -> TF.runSession $ do
|
||||
f <- init
|
||||
-- Can't use replicateM because n is Int64.
|
||||
let go n | n <= 0 = return ()
|
||||
| otherwise = f x >>= liftIO . evaluate >> go (n-1)
|
||||
| otherwise = f x >>= liftIO . evaluate . rnf >> go (n-1)
|
||||
go m
|
||||
|
||||
-- | Benchmark feeding and fetching a vector.
|
||||
|
@ -42,8 +36,8 @@ feedFetchBenchmark = do
|
|||
main :: IO ()
|
||||
main = defaultMain
|
||||
[ bgroup "feedFetch"
|
||||
[ bench "4 byte" $ whnfSession feedFetchBenchmark (V.replicate 1 0)
|
||||
, bench "4 KiB" $ whnfSession feedFetchBenchmark (V.replicate 1024 0)
|
||||
, bench "4 MiB" $ whnfSession feedFetchBenchmark (V.replicate (1024^2) 0)
|
||||
[ bench "4 byte" $ nfSession feedFetchBenchmark (V.replicate 1 0)
|
||||
, bench "4 KiB" $ nfSession feedFetchBenchmark (V.replicate 1024 0)
|
||||
, bench "4 MiB" $ nfSession feedFetchBenchmark (V.replicate (1024^2) 0)
|
||||
]
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue
Block a user