diff --git a/.gitignore b/.gitignore
index f6172ea..a029f1c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 **/.stack-work
 .stack/
+tensorflow-mnist-input-data/data/*.gz
diff --git a/tensorflow-ops/tensorflow-ops.cabal b/tensorflow-ops/tensorflow-ops.cabal
index 3ab99b2..e3905af 100644
--- a/tensorflow-ops/tensorflow-ops.cabal
+++ b/tensorflow-ops/tensorflow-ops.cabal
@@ -202,6 +202,20 @@ Test-Suite TypesTest
                , test-framework-quickcheck2
                , vector
 
+Benchmark FeedFetchBench
+  default-language: Haskell2010
+  type: exitcode-stdio-1.0
+  main-is: FeedFetchBench.hs
+  hs-source-dirs: tests
+  build-depends: base
+               , criterion
+               , deepseq
+               , tensorflow
+               , tensorflow-ops
+               , transformers
+               , vector
+  ghc-options: -O2 -threaded
+
 source-repository head
   type: git
   location: https://github.com/tensorflow/haskell
diff --git a/tensorflow-ops/tests/FeedFetchBench.hs b/tensorflow-ops/tests/FeedFetchBench.hs
new file mode 100644
index 0000000..c7877b3
--- /dev/null
+++ b/tensorflow-ops/tests/FeedFetchBench.hs
@@ -0,0 +1,53 @@
+-- Disable full-laziness to keep ghc from optimizing most of the benchmark away.
+{-# OPTIONS_GHC -fno-full-laziness #-}
+import Control.DeepSeq (NFData(rnf))
+import Control.Exception (evaluate)
+import Control.Monad.IO.Class (liftIO)
+import Criterion.Main (defaultMain, bgroup, bench)
+import Criterion.Types (Benchmarkable(..))
+import qualified Data.Vector as V
+import qualified TensorFlow.Core as TF
+import qualified TensorFlow.Ops as TF
+
+-- | Create 'Benchmarkable' for 'TF.Session'.
+--
+-- The entire benchmark will be run in a single tensorflow session. The
+-- 'TF.Session' argument will be run once and then its result will be run N
+-- times.
+--
+-- Every result is forced to normal form via 'rnf' before the next iteration
+-- starts, so laziness cannot defer or drop the measured work.
+nfSession :: NFData b => TF.Session (a -> TF.Session b) -> a -> Benchmarkable
+nfSession setup x = Benchmarkable $ \iters -> TF.runSession $ do
+    f <- setup
+    -- Can't use replicateM because the iteration count is Int64.
+    let go n | n <= 0 = return ()
+             | otherwise = f x >>= liftIO . evaluate . rnf >> go (n - 1)
+    go iters
+
+-- | Benchmark feeding and fetching a vector.
+--
+-- The setup step builds a placeholder of shape @[-1]@ and renders an identity
+-- op over it. The returned function encodes the given vector as a
+-- one-dimensional tensor of the vector's length, feeds it to the placeholder
+-- and fetches the identity output.
+feedFetchBenchmark :: TF.Session (V.Vector Float -> TF.Session (V.Vector Float))
+feedFetchBenchmark = do
+    input <- TF.build (TF.placeholder (TF.Shape [-1]))
+    output <- TF.build (TF.render (TF.identity input))
+    return $ \v -> do
+        let shape = TF.Shape [fromIntegral (V.length v)]
+            inputData = TF.encodeTensorData shape v
+            feeds = [TF.feed input inputData]
+        TF.runWithFeeds feeds output
+
+-- | Run the feed/fetch benchmark on vectors of 1, 1024 and 1024 * 1024
+-- elements.
+main :: IO ()
+main = defaultMain
+    [ bgroup "feedFetch"
+        [ bench "4 byte" $ nfSession feedFetchBenchmark (V.replicate 1 0)
+        , bench "4 KiB" $ nfSession feedFetchBenchmark (V.replicate 1024 0)
+        , bench "4 MiB" $ nfSession feedFetchBenchmark (V.replicate (1024 * 1024) 0)
+        ]
+    ]
diff --git a/tensorflow/src/TensorFlow/Internal/FFI.hs b/tensorflow/src/TensorFlow/Internal/FFI.hs
index 441fbfb..9eeb982 100644
--- a/tensorflow/src/TensorFlow/Internal/FFI.hs
+++ b/tensorflow/src/TensorFlow/Internal/FFI.hs
@@ -14,6 +14,7 @@
 
 {-# LANGUAGE DeriveDataTypeable #-}
 {-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE ScopedTypeVariables #-}
 
 module TensorFlow.Internal.FFI
     ( TensorFlowException(..)
@@ -34,7 +35,10 @@ import Control.Concurrent.Async (Async, async, cancel, waitCatch)
 import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
 import Control.Exception (Exception, throwIO, bracket, finally, mask_)
 import Control.Monad (when)
+import Data.Bits (Bits, toIntegralSized)
+import Data.Data (Data, dataTypeName, dataTypeOf)
 import Data.Int (Int64)
+import Data.Maybe (fromMaybe)
 import Data.Typeable (Typeable)
 import Data.Word (Word8)
 import Foreign (Ptr, FunPtr, nullPtr, castPtr)
@@ -48,6 +52,7 @@ import qualified Data.Text as T
 import qualified Data.Text.Encoding as T
 import qualified Data.Text.Encoding.Error as T
 import qualified Data.Vector.Storable as S
+import qualified Foreign.Concurrent as ForeignC
 
 import Data.ProtoLens (Message, encodeMessage)
 import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
@@ -133,9 +138,9 @@ run session feeds fetches targets = do
             checkStatus $ Raw.run
                 session
                 nullPtr
-                feedNames cFeedTensors (fromIntegral feedsLen)
-                fetchNames tensorOuts (fromIntegral fetchesLen)
-                ctargets (fromIntegral targetsLen)
+                feedNames cFeedTensors (safeConvert feedsLen)
+                fetchNames tensorOuts (safeConvert fetchesLen)
+                ctargets (safeConvert targetsLen)
                 nullPtr
             mapM_ Raw.deleteTensor feedTensors
             outTensors <- peekArray fetchesLen tensorOuts
@@ -145,6 +150,21 @@ run session feeds fetches targets = do
 
 -- Internal.
 
+-- | Same as 'fromIntegral', but throws an error if conversion is "lossy".
+--
+-- Lossiness is detected with 'toIntegralSized', which returns 'Nothing' when
+-- the value does not fit the target type. The failure is a pure 'error', so
+-- it only surfaces when the converted value is forced.
+safeConvert ::
+    forall a b. (Show a, Show b, Bits a, Bits b, Integral a, Integral b)
+    => a -> b
+safeConvert x =
+    fromMaybe
+    (error ("Failed to convert " ++ show x ++ ", got " ++
+            show (fromIntegral x :: b)))
+    (toIntegralSized x)
+
+
 -- | Use a list of ByteString as a list of CString.
 withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
 withStringList strings fn = go strings []
@@ -162,13 +182,13 @@ withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
 -- | Create a Raw.Tensor from a TensorData.
 createRawTensor :: TensorData -> IO Raw.Tensor
 createRawTensor (TensorData dims dt byteVec) =
-    withArrayLen (map fromIntegral dims) $ \cdimsLen cdims -> do
+    withArrayLen (map safeConvert dims) $ \cdimsLen cdims -> do
         let len = S.length byteVec
         dest <- mallocArray len
         S.unsafeWith byteVec $ \x -> copyArray dest x len
         Raw.newTensor (toEnum $ fromEnum dt)
-                      cdims (fromIntegral cdimsLen)
-                      (castPtr dest) (fromIntegral len)
+                      cdims (safeConvert cdimsLen)
+                      (castPtr dest) (safeConvert len)
                       tensorDeallocFunPtr nullPtr
 
 {-# NOINLINE tensorDeallocFunPtr #-}
@@ -186,13 +206,13 @@ createTensorData t = do
     -- Read type.
     dtype <- toEnum . fromEnum <$> Raw.tensorType t
     -- Read data.
-    len <- fromIntegral <$> Raw.tensorByteSize t
+    len <- safeConvert <$> Raw.tensorByteSize t
     bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
-    -- TODO(fmayle): Don't copy the data.
-    v <- S.fromList <$> peekArray len bytes
-    -- Free tensor.
-    Raw.deleteTensor t
-    return $ TensorData (map fromIntegral dims) dtype v
+    -- Share the tensor's buffer with the vector instead of copying it. The
+    -- tensor is deleted by the finalizer rather than here.
+    fp <- ForeignC.newForeignPtr bytes (Raw.deleteTensor t)
+    let v = S.unsafeFromForeignPtr0 fp len
+    return $ TensorData (map safeConvert dims) dtype v
 
 -- | Runs the given action which does FFI calls updating a provided
 -- status object. If the status is not OK it is thrown as
@@ -218,10 +238,10 @@ setSessionTarget target = B.useAsCString target . Raw.setTarget
 
 -- | Serializes the given msg and provides it as (ptr,len) argument
 -- to the given action.
-useProtoAsVoidPtrLen :: (Message msg, Num c) =>
+useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
                         msg -> (Ptr b -> c -> IO a) -> IO a
 useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
-        \(bytes, len) -> f (castPtr bytes) (fromIntegral len)
+        \(bytes, len) -> f (castPtr bytes) (safeConvert len)
 
 -- | Returns the serialized OpList of all OpDefs defined in this
 -- address space.
@@ -234,7 +254,7 @@ getAllOpList = do
     withForeignPtr foreignPtr $
         \ptr -> B.packCStringLen =<< (,)
                 <$> (castPtr <$> Raw.getBufferData ptr)
-                <*> (fromIntegral <$> Raw.getBufferLength ptr)
+                <*> (safeConvert <$> Raw.getBufferLength ptr)
     where
       checkCall = do
           p <- Raw.getAllOpList