-- mirror of https://github.com/tensorflow/haskell.git
-- tensorflow-haskell/tensorflow/src/TensorFlow/Internal/FFI.hs
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
-- (history: PR #27 "Optimize fetching" — avoid copying tensor data on fetch;
--  2016-11-17 19:41:49 +01:00)
{-# LANGUAGE ScopedTypeVariables #-}
-- 2016-10-24 21:26:42 +02:00
module TensorFlow.Internal.FFI
( TensorFlowException(..)
, Raw.Session
, withSession
, extendGraph
, run
, TensorData(..)
, setSessionConfig
, setSessionTarget
, getAllOpList
-- * Internal helper.
, useProtoAsVoidPtrLen
)
where
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
import Control.Monad (when)
-- (history: PR #27 "Optimize fetching"; 2016-11-17 19:41:49 +01:00)
import Data.Bits (Bits, toIntegralSized)
-- 2016-10-24 21:26:42 +02:00
import Data.Int (Int64)
-- (history: PR #27 "Optimize fetching"; 2016-11-17 19:41:49 +01:00)
import Data.Maybe (fromMaybe)
-- 2016-10-24 21:26:42 +02:00
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S
-- (history: PR #27 "Optimize fetching"; 2016-11-17 19:41:49 +01:00)
import qualified Foreign.Concurrent as ForeignC
-- 2016-10-24 21:26:42 +02:00
import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import qualified TensorFlow.Internal.Raw as Raw
-- | Exception carrying the status code and UTF-8-decoded message
-- reported by the TensorFlow C API when a call does not return TF_OK.
-- Thrown by 'checkStatus'.
data TensorFlowException = TensorFlowException Raw.Code T.Text
    deriving (Show, Eq, Typeable)

instance Exception TensorFlowException
-- | All of the data needed to represent a tensor.
data TensorData = TensorData
    { tensorDataDimensions :: [Int64]
      -- ^ Size of each dimension of the tensor.
    , tensorDataType :: !DataType
      -- ^ Element type, as the protobuf 'DataType' enum.
    , tensorDataBytes :: !(S.Vector Word8)
      -- ^ Raw tensor payload bytes, in the C API's layout.
    }
    deriving (Show, Eq)
-- | Runs the given action after creating a session with options
-- populated by the given optionSetter.
withSession :: (Raw.SessionOptions -> IO ())
            -> ((IO () -> IO ()) -> Raw.Session -> IO a)
               -- ^ The action can spawn concurrent tasks which will
               -- be canceled before withSession returns.
            -> IO a
withSession optionSetter action = do
    pendingRunners <- newMVar []
    bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
        optionSetter options
        bracket (checkStatus (Raw.newSession options))
                (teardown pendingRunners)
                (action (asyncCollector pendingRunners))
  where
    -- Closing the session first nudges the pending run calls to fail
    -- and exit; only then are the runners collected and the session
    -- deleted.
    teardown pendingRunners session =
        finally (checkStatus (Raw.closeSession session)) $ do
            runners <- takeMVar pendingRunners
            mapM_ shutDownRunner runners
            checkStatus (Raw.deleteSession session)
-- | Launches the given runner as an 'Async' task and records it in the
-- drain list so 'withSession' can cancel it during teardown.
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner =
    modifyMVarMasked_ drain $ \running -> do
        task <- async runner
        return (task : running)
-- | Cancels a runner task and waits for it to finish, printing any
-- exception it raised.
shutDownRunner :: Async () -> IO ()
shutDownRunner task = do
    cancel task
    outcome <- waitCatch task
    -- TODO(gnezdo): manage exceptions better than print.
    case outcome of
        Left err -> print err
        Right () -> return ()
-- | Extends the session's graph with the nodes in the given 'GraphDef',
-- throwing 'TensorFlowException' if the underlying call fails.
extendGraph :: Raw.Session -> GraphDef -> IO ()
extendGraph session pb =
    useProtoAsVoidPtrLen pb $ \ptr len ->
        checkStatus $ Raw.extendGraph session ptr len
-- | Runs a step of the graph in the given session, feeding the given
-- tensors and returning the data for the requested fetches.
--
-- Throws 'TensorFlowException' if the underlying run call fails.
run :: Raw.Session
    -> [(B.ByteString, TensorData)]  -- ^ Feeds: (tensor name, data).
    -> [B.ByteString]                -- ^ Fetches: tensor names to read back.
    -> [B.ByteString]                -- ^ Targets: node names to execute.
    -> IO [TensorData]
run session feeds fetches targets = do
    let nullTensor = Raw.Tensor nullPtr
    -- Use mask to avoid leaking input tensors before they are passed to 'run'
    -- and output tensors before they are passed to 'createTensorData'.
    mask_ $
        -- Feeds
        withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
        mapM (createRawTensor . snd) feeds >>= \feedTensors ->
        withArrayLen feedTensors $ \_ cFeedTensors ->
        -- Fetches.
        withStringArrayLen fetches $ \fetchesLen fetchNames ->
        -- tensorOuts is an array of null Tensor pointers that will be filled
        -- by the call to Raw.run.
        withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
        -- Targets.
        withStringArrayLen targets $ \targetsLen ctargets -> do
            checkStatus $ Raw.run
                session
                nullPtr
                feedNames cFeedTensors (safeConvert feedsLen)
                fetchNames tensorOuts (safeConvert fetchesLen)
                ctargets (safeConvert targetsLen)
                nullPtr
            -- NOTE(review): assumes Raw.run does not take ownership of the
            -- feed tensors, so they are freed here — confirm against the
            -- C API binding.  If 'checkStatus' throws above, the feed
            -- tensors are not freed on that path (pre-existing behavior).
            mapM_ Raw.deleteTensor feedTensors
            -- Ownership of the fetched tensors transfers to
            -- 'createTensorData', which attaches a finalizer.
            outTensors <- peekArray fetchesLen tensorOuts
            mapM createTensorData outTensors
-- Internal.

-- | Same as 'fromIntegral', but calls 'error' if the conversion is
-- "lossy" — i.e. the value does not round-trip into the target type
-- (detected via 'toIntegralSized').
safeConvert ::
    forall a b. (Show a, Show b, Bits a, Bits b, Integral a, Integral b)
    => a -> b
safeConvert x =
    fromMaybe
        (error ("Failed to convert " ++ show x ++ ", got " ++
                show (fromIntegral x :: b)))
        (toIntegralSized x)
-- | Use a list of ByteString as a list of CString.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList byteStrings action = loop byteStrings []
  where
    -- Accumulate CStrings in reverse, restoring order at the end so
    -- each 'useAsCString' scope stays open around the action.
    loop [] acc = action (reverse acc)
    -- TODO(fmayle): Is it worth using unsafeAsCString here?
    loop (bs:rest) acc = B.useAsCString bs $ \cstr -> loop rest (cstr : acc)
-- | Use a list of ByteString as a contiguous array of CString, also
-- providing the element count.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen byteStrings action =
    withStringList byteStrings (\cstrings -> withArrayLen cstrings action)
-- | Create a Raw.Tensor from a TensorData.
--
-- The payload is copied into freshly malloc'd memory; ownership of the
-- copy passes to the new tensor, which releases it via
-- 'tensorDeallocFunPtr'.
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) =
    withArrayLen (map safeConvert dims) $ \cdimsLen cdims -> do
        let len = S.length byteVec
        dest <- mallocArray len
        S.unsafeWith byteVec $ \src -> copyArray dest src len
        Raw.newTensor (toEnum $ fromEnum dt)
                      cdims (safeConvert cdimsLen)
                      (castPtr dest) (safeConvert len)
                      tensorDeallocFunPtr nullPtr
-- | Deallocator handed to the C API for tensors whose payload was
-- allocated with 'mallocArray' (see 'createRawTensor'): it simply
-- frees the data pointer.
--
-- NOTE(review): 'unsafePerformIO' appears safe here — wrapping the
-- callback has no observable effect besides creating the FunPtr, and
-- NOINLINE ensures the wrapper is created only once.
{-# NOINLINE tensorDeallocFunPtr #-}
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x
-- | Create a TensorData from a Raw.Tensor.
--
-- Takes ownership of the Raw.Tensor: the data buffer is wrapped
-- without copying, and the tensor is deleted by the foreign pointer's
-- finalizer once the returned vector is garbage collected.
createTensorData :: Raw.Tensor -> IO TensorData
createTensorData t = do
    -- Read dimensions.
    numDims <- Raw.numDims t
    dims <- mapM (Raw.dim t) [0..numDims-1]
    -- Read type.
    dtype <- toEnum . fromEnum <$> Raw.tensorType t
    -- Read data: share the tensor's buffer; the finalizer deletes the
    -- tensor (and with it the buffer) when the vector is GC'd.
    len <- safeConvert <$> Raw.tensorByteSize t
    bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
    fp <- ForeignC.newForeignPtr bytes (Raw.deleteTensor t)
    let v = S.unsafeFromForeignPtr0 fp len
    return $ TensorData (map safeConvert dims) dtype v
-- | Runs the given action which does FFI calls updating a provided
-- status object.  If the status is not OK it is thrown as
-- 'TensorFlowException'.
checkStatus :: (Raw.Status -> IO a) -> IO a
checkStatus action =
    bracket Raw.newStatus Raw.deleteStatus $ \status -> do
        result <- action status
        code <- Raw.getCode status
        if code == Raw.TF_OK
            then return result
            else do
                rawMsg <- Raw.message status >>= B.packCString
                throwIO (TensorFlowException code
                             (T.decodeUtf8With T.lenientDecode rawMsg))
-- | Installs the serialized 'ConfigProto' on the given session options,
-- throwing 'TensorFlowException' on failure.
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
setSessionConfig proto options =
    useProtoAsVoidPtrLen proto $ \ptr len ->
        checkStatus $ Raw.setConfig options ptr len
-- | Sets the target (e.g. a grpc address) on the given session options.
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
setSessionTarget target options =
    B.useAsCString target (Raw.setTarget options)
-- | Serializes the given msg and provides it as (ptr, len) argument
-- to the given action.  The length is converted with 'safeConvert',
-- so an overly large message errors out instead of truncating.
useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
                        msg -> (Ptr b -> c -> IO a) -> IO a
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
    \(bytes, len) -> f (castPtr bytes) (safeConvert len)
-- | Returns the serialized OpList of all OpDefs defined in this
-- address space.
--
-- Throws 'TensorFlowException' if the underlying C call returns NULL.
getAllOpList :: IO B.ByteString
getAllOpList = do
    -- mask_ so the buffer cannot leak between the C call and the
    -- attachment of the deleteBuffer finalizer.
    foreignPtr <-
        mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall)
    -- Makes a copy because it is more reliable than eviscerating
    -- Buffer to steal its memory (including custom deallocator).
    withForeignPtr foreignPtr $
        \ptr -> B.packCStringLen =<< (,)
            <$> (castPtr <$> Raw.getBufferData ptr)
            <*> (safeConvert <$> Raw.getBufferLength ptr)
  where
    checkCall = do
        p <- Raw.getAllOpList
        when (p == nullPtr) (throwIO exception)
        return p
    exception = TensorFlowException
                    Raw.TF_UNKNOWN "GetAllOpList failure, check logs"