mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-11 19:39:49 +01:00
Optimize fetching (#27)
* Add MNIST data to gitignore * Add simple tensor round-trip benchmark * Use deepseq + cleaner imports * Use safe version of fromIntegral in FFI code * Don't copy data when fetching tensors BEFORE benchmarking feedFetch/4 byte time 55.79 μs (54.88 μs .. 56.62 μs) 0.998 R² (0.997 R² .. 0.999 R²) mean 55.61 μs (55.09 μs .. 56.11 μs) std dev 1.828 μs (1.424 μs .. 2.518 μs) variance introduced by outliers: 34% (moderately inflated) benchmarking feedFetch/4 KiB time 231.4 μs (221.9 μs .. 247.3 μs) 0.988 R² (0.974 R² .. 1.000 R²) mean 226.6 μs (224.1 μs .. 236.2 μs) std dev 13.45 μs (7.115 μs .. 27.14 μs) variance introduced by outliers: 57% (severely inflated) benchmarking feedFetch/4 MiB time 485.8 ms (424.6 ms .. 526.7 ms) 0.998 R² (0.994 R² .. 1.000 R²) mean 515.7 ms (512.5 ms .. 517.9 ms) std dev 3.320 ms (0.0 s .. 3.822 ms) variance introduced by outliers: 19% (moderately inflated) AFTER benchmarking feedFetch/4 byte time 53.11 μs (52.12 μs .. 54.22 μs) 0.996 R² (0.995 R² .. 0.998 R²) mean 54.64 μs (53.59 μs .. 56.18 μs) std dev 4.249 μs (2.910 μs .. 6.076 μs) variance introduced by outliers: 75% (severely inflated) benchmarking feedFetch/4 KiB time 83.83 μs (82.72 μs .. 84.92 μs) 0.999 R² (0.998 R² .. 0.999 R²) mean 83.82 μs (83.20 μs .. 84.35 μs) std dev 1.943 μs (1.557 μs .. 2.614 μs) variance introduced by outliers: 20% (moderately inflated) benchmarking feedFetch/4 MiB time 95.54 ms (93.62 ms .. 97.82 ms) 0.999 R² (0.998 R² .. 1.000 R²) mean 96.61 ms (95.76 ms .. 97.51 ms) std dev 1.408 ms (1.005 ms .. 1.889 ms)
This commit is contained in:
parent
c430e54c3c
commit
fc3d398ca9
4 changed files with 87 additions and 15 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,2 +1,3 @@
|
|||
**/.stack-work
|
||||
.stack/
|
||||
tensorflow-mnist-input-data/data/*.gz
|
||||
|
|
|
@ -202,6 +202,20 @@ Test-Suite TypesTest
|
|||
, test-framework-quickcheck2
|
||||
, vector
|
||||
|
||||
Benchmark FeedFetchBench
|
||||
default-language: Haskell2010
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: FeedFetchBench.hs
|
||||
hs-source-dirs: tests
|
||||
build-depends: base
|
||||
, criterion
|
||||
, deepseq
|
||||
, tensorflow
|
||||
, tensorflow-ops
|
||||
, transformers
|
||||
, vector
|
||||
ghc-options: -O2 -threaded
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/tensorflow/haskell
|
||||
|
|
43
tensorflow-ops/tests/FeedFetchBench.hs
Normal file
43
tensorflow-ops/tests/FeedFetchBench.hs
Normal file
|
@ -0,0 +1,43 @@
|
|||
-- Disable full-laziness to keep ghc from optimizing most of the benchmark away.
|
||||
{-# OPTIONS_GHC -fno-full-laziness #-}
|
||||
import Control.DeepSeq (NFData(rnf))
|
||||
import Control.Exception (evaluate)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Criterion.Main (defaultMain, bgroup, bench)
|
||||
import Criterion.Types (Benchmarkable(..))
|
||||
import qualified Data.Vector as V
|
||||
import qualified TensorFlow.Core as TF
|
||||
import qualified TensorFlow.Ops as TF
|
||||
|
||||
-- | Create 'Benchmarkable' for 'TF.Session'.
|
||||
--
|
||||
-- The entire benchmark will be run in a single tensorflow session. The
|
||||
-- 'TF.Session' argument will be run once and then its result will be run N
|
||||
-- times.
|
||||
nfSession :: NFData b => TF.Session (a -> TF.Session b) -> a -> Benchmarkable
|
||||
nfSession init x = Benchmarkable $ \m -> TF.runSession $ do
|
||||
f <- init
|
||||
-- Can't use replicateM because n is Int64.
|
||||
let go n | n <= 0 = return ()
|
||||
| otherwise = f x >>= liftIO . evaluate . rnf >> go (n-1)
|
||||
go m
|
||||
|
||||
-- | Benchmark feeding and fetching a vector.
|
||||
feedFetchBenchmark :: TF.Session (V.Vector Float -> TF.Session (V.Vector Float))
|
||||
feedFetchBenchmark = do
|
||||
input <- TF.build (TF.placeholder (TF.Shape [-1]))
|
||||
output <- TF.build (TF.render (TF.identity input))
|
||||
return $ \v -> do
|
||||
let shape = TF.Shape [fromIntegral (V.length v)]
|
||||
inputData = TF.encodeTensorData shape v
|
||||
feeds = [TF.feed input inputData]
|
||||
TF.runWithFeeds feeds output
|
||||
|
||||
main :: IO ()
|
||||
main = defaultMain
|
||||
[ bgroup "feedFetch"
|
||||
[ bench "4 byte" $ nfSession feedFetchBenchmark (V.replicate 1 0)
|
||||
, bench "4 KiB" $ nfSession feedFetchBenchmark (V.replicate 1024 0)
|
||||
, bench "4 MiB" $ nfSession feedFetchBenchmark (V.replicate (1024^2) 0)
|
||||
]
|
||||
]
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
{-# LANGUAGE DeriveDataTypeable #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module TensorFlow.Internal.FFI
|
||||
( TensorFlowException(..)
|
||||
|
@ -34,7 +35,10 @@ import Control.Concurrent.Async (Async, async, cancel, waitCatch)
|
|||
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
||||
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
|
||||
import Control.Monad (when)
|
||||
import Data.Bits (Bits, toIntegralSized)
|
||||
import Data.Data (Data, dataTypeName, dataTypeOf)
|
||||
import Data.Int (Int64)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Typeable (Typeable)
|
||||
import Data.Word (Word8)
|
||||
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
|
||||
|
@ -48,6 +52,7 @@ import qualified Data.Text as T
|
|||
import qualified Data.Text.Encoding as T
|
||||
import qualified Data.Text.Encoding.Error as T
|
||||
import qualified Data.Vector.Storable as S
|
||||
import qualified Foreign.Concurrent as ForeignC
|
||||
|
||||
import Data.ProtoLens (Message, encodeMessage)
|
||||
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
|
||||
|
@ -133,9 +138,9 @@ run session feeds fetches targets = do
|
|||
checkStatus $ Raw.run
|
||||
session
|
||||
nullPtr
|
||||
feedNames cFeedTensors (fromIntegral feedsLen)
|
||||
fetchNames tensorOuts (fromIntegral fetchesLen)
|
||||
ctargets (fromIntegral targetsLen)
|
||||
feedNames cFeedTensors (safeConvert feedsLen)
|
||||
fetchNames tensorOuts (safeConvert fetchesLen)
|
||||
ctargets (safeConvert targetsLen)
|
||||
nullPtr
|
||||
mapM_ Raw.deleteTensor feedTensors
|
||||
outTensors <- peekArray fetchesLen tensorOuts
|
||||
|
@ -145,6 +150,17 @@ run session feeds fetches targets = do
|
|||
-- Internal.
|
||||
|
||||
|
||||
-- | Same as 'fromIntegral', but throws an error if conversion is "lossy".
|
||||
safeConvert ::
|
||||
forall a b. (Show a, Show b, Bits a, Bits b, Integral a, Integral b)
|
||||
=> a -> b
|
||||
safeConvert x =
|
||||
fromMaybe
|
||||
(error ("Failed to convert " ++ show x ++ ", got " ++
|
||||
show (fromIntegral x :: b)))
|
||||
(toIntegralSized x)
|
||||
|
||||
|
||||
-- | Use a list of ByteString as a list of CString.
|
||||
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
|
||||
withStringList strings fn = go strings []
|
||||
|
@ -162,13 +178,13 @@ withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
|
|||
-- | Create a Raw.Tensor from a TensorData.
|
||||
createRawTensor :: TensorData -> IO Raw.Tensor
|
||||
createRawTensor (TensorData dims dt byteVec) =
|
||||
withArrayLen (map fromIntegral dims) $ \cdimsLen cdims -> do
|
||||
withArrayLen (map safeConvert dims) $ \cdimsLen cdims -> do
|
||||
let len = S.length byteVec
|
||||
dest <- mallocArray len
|
||||
S.unsafeWith byteVec $ \x -> copyArray dest x len
|
||||
Raw.newTensor (toEnum $ fromEnum dt)
|
||||
cdims (fromIntegral cdimsLen)
|
||||
(castPtr dest) (fromIntegral len)
|
||||
cdims (safeConvert cdimsLen)
|
||||
(castPtr dest) (safeConvert len)
|
||||
tensorDeallocFunPtr nullPtr
|
||||
|
||||
{-# NOINLINE tensorDeallocFunPtr #-}
|
||||
|
@ -186,13 +202,11 @@ createTensorData t = do
|
|||
-- Read type.
|
||||
dtype <- toEnum . fromEnum <$> Raw.tensorType t
|
||||
-- Read data.
|
||||
len <- fromIntegral <$> Raw.tensorByteSize t
|
||||
len <- safeConvert <$> Raw.tensorByteSize t
|
||||
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
|
||||
-- TODO(fmayle): Don't copy the data.
|
||||
v <- S.fromList <$> peekArray len bytes
|
||||
-- Free tensor.
|
||||
Raw.deleteTensor t
|
||||
return $ TensorData (map fromIntegral dims) dtype v
|
||||
fp <- ForeignC.newForeignPtr bytes (Raw.deleteTensor t)
|
||||
let v = S.unsafeFromForeignPtr0 fp len
|
||||
return $ TensorData (map safeConvert dims) dtype v
|
||||
|
||||
-- | Runs the given action which does FFI calls updating a provided
|
||||
-- status object. If the status is not OK it is thrown as
|
||||
|
@ -218,10 +232,10 @@ setSessionTarget target = B.useAsCString target . Raw.setTarget
|
|||
|
||||
-- | Serializes the given msg and provides it as (ptr,len) argument
|
||||
-- to the given action.
|
||||
useProtoAsVoidPtrLen :: (Message msg, Num c) =>
|
||||
useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
|
||||
msg -> (Ptr b -> c -> IO a) -> IO a
|
||||
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
|
||||
\(bytes, len) -> f (castPtr bytes) (fromIntegral len)
|
||||
\(bytes, len) -> f (castPtr bytes) (safeConvert len)
|
||||
|
||||
-- | Returns the serialized OpList of all OpDefs defined in this
|
||||
-- address space.
|
||||
|
@ -234,7 +248,7 @@ getAllOpList = do
|
|||
withForeignPtr foreignPtr $
|
||||
\ptr -> B.packCStringLen =<< (,)
|
||||
<$> (castPtr <$> Raw.getBufferData ptr)
|
||||
<*> (fromIntegral <$> Raw.getBufferLength ptr)
|
||||
<*> (safeConvert <$> Raw.getBufferLength ptr)
|
||||
where
|
||||
checkCall = do
|
||||
p <- Raw.getAllOpList
|
||||
|
|
Loading…
Reference in a new issue