Optimize fetching (#27)

* Add MNIST data to gitignore
* Add simple tensor round-trip benchmark
* Use deepseq + cleaner imports
* Use safe version of fromIntegral in FFI code
* Don't copy data when fetching tensors

BEFORE

benchmarking feedFetch/4 byte
time                 55.79 μs   (54.88 μs .. 56.62 μs)
                     0.998 R²   (0.997 R² .. 0.999 R²)
mean                 55.61 μs   (55.09 μs .. 56.11 μs)
std dev              1.828 μs   (1.424 μs .. 2.518 μs)
variance introduced by outliers: 34% (moderately inflated)

benchmarking feedFetch/4 KiB
time                 231.4 μs   (221.9 μs .. 247.3 μs)
                     0.988 R²   (0.974 R² .. 1.000 R²)
mean                 226.6 μs   (224.1 μs .. 236.2 μs)
std dev              13.45 μs   (7.115 μs .. 27.14 μs)
variance introduced by outliers: 57% (severely inflated)

benchmarking feedFetch/4 MiB
time                 485.8 ms   (424.6 ms .. 526.7 ms)
                     0.998 R²   (0.994 R² .. 1.000 R²)
mean                 515.7 ms   (512.5 ms .. 517.9 ms)
std dev              3.320 ms   (0.0 s .. 3.822 ms)
variance introduced by outliers: 19% (moderately inflated)

AFTER

benchmarking feedFetch/4 byte
time                 53.11 μs   (52.12 μs .. 54.22 μs)
                     0.996 R²   (0.995 R² .. 0.998 R²)
mean                 54.64 μs   (53.59 μs .. 56.18 μs)
std dev              4.249 μs   (2.910 μs .. 6.076 μs)
variance introduced by outliers: 75% (severely inflated)

benchmarking feedFetch/4 KiB
time                 83.83 μs   (82.72 μs .. 84.92 μs)
                     0.999 R²   (0.998 R² .. 0.999 R²)
mean                 83.82 μs   (83.20 μs .. 84.35 μs)
std dev              1.943 μs   (1.557 μs .. 2.614 μs)
variance introduced by outliers: 20% (moderately inflated)

benchmarking feedFetch/4 MiB
time                 95.54 ms   (93.62 ms .. 97.82 ms)
                     0.999 R²   (0.998 R² .. 1.000 R²)
mean                 96.61 ms   (95.76 ms .. 97.51 ms)
std dev              1.408 ms   (1.005 ms .. 1.889 ms)
This commit is contained in:
fkm3 2016-11-17 10:41:49 -08:00 committed by Greg Steuck
parent c430e54c3c
commit fc3d398ca9
4 changed files with 87 additions and 15 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
**/.stack-work **/.stack-work
.stack/ .stack/
tensorflow-mnist-input-data/data/*.gz

View File

@ -202,6 +202,20 @@ Test-Suite TypesTest
, test-framework-quickcheck2 , test-framework-quickcheck2
, vector , vector
Benchmark FeedFetchBench
default-language: Haskell2010
type: exitcode-stdio-1.0
main-is: FeedFetchBench.hs
hs-source-dirs: tests
build-depends: base
, criterion
, deepseq
, tensorflow
, tensorflow-ops
, transformers
, vector
ghc-options: -O2 -threaded
source-repository head source-repository head
type: git type: git
location: https://github.com/tensorflow/haskell location: https://github.com/tensorflow/haskell

View File

@ -0,0 +1,43 @@
-- Disable full-laziness to keep ghc from optimizing most of the benchmark away.
{-# OPTIONS_GHC -fno-full-laziness #-}
import Control.DeepSeq (NFData(rnf))
import Control.Exception (evaluate)
import Control.Monad.IO.Class (liftIO)
import Criterion.Main (defaultMain, bgroup, bench)
import Criterion.Types (Benchmarkable(..))
import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Ops as TF
-- | Build a criterion 'Benchmarkable' out of a 'TF.Session' action.
--
-- Every iteration shares one tensorflow session: the setup action is executed
-- a single time, and the function it yields is then applied to the input once
-- per requested iteration, with each result forced via 'rnf'.
nfSession :: NFData b => TF.Session (a -> TF.Session b) -> a -> Benchmarkable
nfSession setup input = Benchmarkable $ \iterations -> TF.runSession $ do
    step <- setup
    -- The iteration count is an Int64, so replicateM_ is not an option here.
    let loop remaining
            | remaining > 0 = do
                result <- step input
                liftIO (evaluate (rnf result))
                loop (remaining - 1)
            | otherwise = return ()
    loop iterations
-- | Round-trip a vector through tensorflow: feed it in, fetch it back.
--
-- The outer action builds a placeholder (declared with shape @[-1]@) and an
-- identity op over it. The function it returns encodes the given vector with
-- a shape equal to its length, feeds it to the placeholder and fetches the
-- identity op's output.
feedFetchBenchmark :: TF.Session (V.Vector Float -> TF.Session (V.Vector Float))
feedFetchBenchmark = do
    source <- TF.build $ TF.placeholder (TF.Shape [-1])
    sink <- TF.build . TF.render $ TF.identity source
    return $ \vec ->
        let dims = TF.Shape [fromIntegral (V.length vec)]
            payload = TF.encodeTensorData dims vec
        in TF.runWithFeeds [TF.feed source payload] sink
-- | Run the feed/fetch benchmark group for three vector lengths.
main :: IO ()
main = defaultMain [bgroup "feedFetch" (map sized cases)]
  where
    -- Benchmark label paired with the number of Float elements to send.
    cases :: [(String, Int)]
    cases =
        [ ("4 byte", 1)
        , ("4 KiB", 1024)
        , ("4 MiB", 1024 ^ 2)
        ]
    sized (label, count) =
        bench label $ nfSession feedFetchBenchmark (V.replicate count 0)

View File

@ -14,6 +14,7 @@
{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module TensorFlow.Internal.FFI module TensorFlow.Internal.FFI
( TensorFlowException(..) ( TensorFlowException(..)
@ -34,7 +35,10 @@ import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar) import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception (Exception, throwIO, bracket, finally, mask_) import Control.Exception (Exception, throwIO, bracket, finally, mask_)
import Control.Monad (when) import Control.Monad (when)
import Data.Bits (Bits, toIntegralSized)
import Data.Data (Data, dataTypeName, dataTypeOf)
import Data.Int (Int64) import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable) import Data.Typeable (Typeable)
import Data.Word (Word8) import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr) import Foreign (Ptr, FunPtr, nullPtr, castPtr)
@ -48,6 +52,7 @@ import qualified Data.Text as T
import qualified Data.Text.Encoding as T import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S import qualified Data.Vector.Storable as S
import qualified Foreign.Concurrent as ForeignC
import Data.ProtoLens (Message, encodeMessage) import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef) import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
@ -133,9 +138,9 @@ run session feeds fetches targets = do
checkStatus $ Raw.run checkStatus $ Raw.run
session session
nullPtr nullPtr
feedNames cFeedTensors (fromIntegral feedsLen) feedNames cFeedTensors (safeConvert feedsLen)
fetchNames tensorOuts (fromIntegral fetchesLen) fetchNames tensorOuts (safeConvert fetchesLen)
ctargets (fromIntegral targetsLen) ctargets (safeConvert targetsLen)
nullPtr nullPtr
mapM_ Raw.deleteTensor feedTensors mapM_ Raw.deleteTensor feedTensors
outTensors <- peekArray fetchesLen tensorOuts outTensors <- peekArray fetchesLen tensorOuts
@ -145,6 +150,17 @@ run session feeds fetches targets = do
-- Internal. -- Internal.
-- | Checked variant of 'fromIntegral'.
--
-- Returns the converted value when 'toIntegralSized' succeeds. Otherwise it
-- calls 'error' with a message showing the input and the value a plain
-- 'fromIntegral' would have produced.
safeConvert ::
    forall a b. (Show a, Show b, Bits a, Bits b, Integral a, Integral b)
    => a -> b
safeConvert x = case toIntegralSized x of
    Just converted -> converted
    Nothing -> error message
  where
    -- What an unchecked conversion yields; only used for the error text.
    wrapped = fromIntegral x :: b
    message = "Failed to convert " ++ show x ++ ", got " ++ show wrapped
-- | Use a list of ByteString as a list of CString. -- | Use a list of ByteString as a list of CString.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = go strings [] withStringList strings fn = go strings []
@ -162,13 +178,13 @@ withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
-- | Create a Raw.Tensor from a TensorData. -- | Create a Raw.Tensor from a TensorData.
createRawTensor :: TensorData -> IO Raw.Tensor createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) = createRawTensor (TensorData dims dt byteVec) =
withArrayLen (map fromIntegral dims) $ \cdimsLen cdims -> do withArrayLen (map safeConvert dims) $ \cdimsLen cdims -> do
let len = S.length byteVec let len = S.length byteVec
dest <- mallocArray len dest <- mallocArray len
S.unsafeWith byteVec $ \x -> copyArray dest x len S.unsafeWith byteVec $ \x -> copyArray dest x len
Raw.newTensor (toEnum $ fromEnum dt) Raw.newTensor (toEnum $ fromEnum dt)
cdims (fromIntegral cdimsLen) cdims (safeConvert cdimsLen)
(castPtr dest) (fromIntegral len) (castPtr dest) (safeConvert len)
tensorDeallocFunPtr nullPtr tensorDeallocFunPtr nullPtr
{-# NOINLINE tensorDeallocFunPtr #-} {-# NOINLINE tensorDeallocFunPtr #-}
@ -186,13 +202,11 @@ createTensorData t = do
-- Read type. -- Read type.
dtype <- toEnum . fromEnum <$> Raw.tensorType t dtype <- toEnum . fromEnum <$> Raw.tensorType t
-- Read data. -- Read data.
len <- fromIntegral <$> Raw.tensorByteSize t len <- safeConvert <$> Raw.tensorByteSize t
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8) bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
-- TODO(fmayle): Don't copy the data. fp <- ForeignC.newForeignPtr bytes (Raw.deleteTensor t)
v <- S.fromList <$> peekArray len bytes let v = S.unsafeFromForeignPtr0 fp len
-- Free tensor. return $ TensorData (map safeConvert dims) dtype v
Raw.deleteTensor t
return $ TensorData (map fromIntegral dims) dtype v
-- | Runs the given action which does FFI calls updating a provided -- | Runs the given action which does FFI calls updating a provided
-- status object. If the status is not OK it is thrown as -- status object. If the status is not OK it is thrown as
@ -218,10 +232,10 @@ setSessionTarget target = B.useAsCString target . Raw.setTarget
-- | Serializes the given msg and provides it as (ptr,len) argument -- | Serializes the given msg and provides it as (ptr,len) argument
-- to the given action. -- to the given action.
useProtoAsVoidPtrLen :: (Message msg, Num c) => useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
msg -> (Ptr b -> c -> IO a) -> IO a msg -> (Ptr b -> c -> IO a) -> IO a
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $ useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
\(bytes, len) -> f (castPtr bytes) (fromIntegral len) \(bytes, len) -> f (castPtr bytes) (safeConvert len)
-- | Returns the serialized OpList of all OpDefs defined in this -- | Returns the serialized OpList of all OpDefs defined in this
-- address space. -- address space.
@ -234,7 +248,7 @@ getAllOpList = do
withForeignPtr foreignPtr $ withForeignPtr foreignPtr $
\ptr -> B.packCStringLen =<< (,) \ptr -> B.packCStringLen =<< (,)
<$> (castPtr <$> Raw.getBufferData ptr) <$> (castPtr <$> Raw.getBufferData ptr)
<*> (fromIntegral <$> Raw.getBufferLength ptr) <*> (safeConvert <$> Raw.getBufferLength ptr)
where where
checkCall = do checkCall = do
p <- Raw.getAllOpList p <- Raw.getAllOpList