tensorflow-haskell/tensorflow/src/TensorFlow/Internal/FFI.hs

280 lines
11 KiB
Haskell
Raw Normal View History

2016-10-24 21:26:42 +02:00
-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
Optimize fetching (#27) * Add MNIST data to gitignore * Add simple tensor round-trip benchmark * Use deepseq + cleaner imports * Use safe version of fromIntegral in FFI code * Don't copy data when fetching tensors BEFORE benchmarking feedFetch/4 byte time 55.79 μs (54.88 μs .. 56.62 μs) 0.998 R² (0.997 R² .. 0.999 R²) mean 55.61 μs (55.09 μs .. 56.11 μs) std dev 1.828 μs (1.424 μs .. 2.518 μs) variance introduced by outliers: 34% (moderately inflated) benchmarking feedFetch/4 KiB time 231.4 μs (221.9 μs .. 247.3 μs) 0.988 R² (0.974 R² .. 1.000 R²) mean 226.6 μs (224.1 μs .. 236.2 μs) std dev 13.45 μs (7.115 μs .. 27.14 μs) variance introduced by outliers: 57% (severely inflated) benchmarking feedFetch/4 MiB time 485.8 ms (424.6 ms .. 526.7 ms) 0.998 R² (0.994 R² .. 1.000 R²) mean 515.7 ms (512.5 ms .. 517.9 ms) std dev 3.320 ms (0.0 s .. 3.822 ms) variance introduced by outliers: 19% (moderately inflated) AFTER benchmarking feedFetch/4 byte time 53.11 μs (52.12 μs .. 54.22 μs) 0.996 R² (0.995 R² .. 0.998 R²) mean 54.64 μs (53.59 μs .. 56.18 μs) std dev 4.249 μs (2.910 μs .. 6.076 μs) variance introduced by outliers: 75% (severely inflated) benchmarking feedFetch/4 KiB time 83.83 μs (82.72 μs .. 84.92 μs) 0.999 R² (0.998 R² .. 0.999 R²) mean 83.82 μs (83.20 μs .. 84.35 μs) std dev 1.943 μs (1.557 μs .. 2.614 μs) variance introduced by outliers: 20% (moderately inflated) benchmarking feedFetch/4 MiB time 95.54 ms (93.62 ms .. 97.82 ms) 0.999 R² (0.998 R² .. 1.000 R²) mean 96.61 ms (95.76 ms .. 97.51 ms) std dev 1.408 ms (1.005 ms .. 1.889 ms)
2016-11-17 19:41:49 +01:00
{-# LANGUAGE ScopedTypeVariables #-}
2016-10-24 21:26:42 +02:00
module TensorFlow.Internal.FFI
( TensorFlowException(..)
, Raw.Session
, withSession
, extendGraph
, run
, TensorData(..)
, setSessionConfig
, setSessionTarget
, getAllOpList
, unsafeTStringToByteString
2016-10-24 21:26:42 +02:00
-- * Internal helper.
, useProtoAsVoidPtrLen
)
where
import Control.Exception (assert)
2016-10-24 21:26:42 +02:00
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Monad (when)
import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask_)
import Control.Monad.IO.Class (MonadIO, liftIO)
Optimize fetching (#27) * Add MNIST data to gitignore * Add simple tensor round-trip benchmark * Use deepseq + cleaner imports * Use safe version of fromIntegral in FFI code * Don't copy data when fetching tensors BEFORE benchmarking feedFetch/4 byte time 55.79 μs (54.88 μs .. 56.62 μs) 0.998 R² (0.997 R² .. 0.999 R²) mean 55.61 μs (55.09 μs .. 56.11 μs) std dev 1.828 μs (1.424 μs .. 2.518 μs) variance introduced by outliers: 34% (moderately inflated) benchmarking feedFetch/4 KiB time 231.4 μs (221.9 μs .. 247.3 μs) 0.988 R² (0.974 R² .. 1.000 R²) mean 226.6 μs (224.1 μs .. 236.2 μs) std dev 13.45 μs (7.115 μs .. 27.14 μs) variance introduced by outliers: 57% (severely inflated) benchmarking feedFetch/4 MiB time 485.8 ms (424.6 ms .. 526.7 ms) 0.998 R² (0.994 R² .. 1.000 R²) mean 515.7 ms (512.5 ms .. 517.9 ms) std dev 3.320 ms (0.0 s .. 3.822 ms) variance introduced by outliers: 19% (moderately inflated) AFTER benchmarking feedFetch/4 byte time 53.11 μs (52.12 μs .. 54.22 μs) 0.996 R² (0.995 R² .. 0.998 R²) mean 54.64 μs (53.59 μs .. 56.18 μs) std dev 4.249 μs (2.910 μs .. 6.076 μs) variance introduced by outliers: 75% (severely inflated) benchmarking feedFetch/4 KiB time 83.83 μs (82.72 μs .. 84.92 μs) 0.999 R² (0.998 R² .. 0.999 R²) mean 83.82 μs (83.20 μs .. 84.35 μs) std dev 1.943 μs (1.557 μs .. 2.614 μs) variance introduced by outliers: 20% (moderately inflated) benchmarking feedFetch/4 MiB time 95.54 ms (93.62 ms .. 97.82 ms) 0.999 R² (0.998 R² .. 1.000 R²) mean 96.61 ms (95.76 ms .. 97.51 ms) std dev 1.408 ms (1.005 ms .. 1.889 ms)
2016-11-17 19:41:49 +01:00
import Data.Bits (Bits, toIntegralSized)
2016-10-24 21:26:42 +02:00
import Data.Int (Int64)
Optimize fetching (#27) * Add MNIST data to gitignore * Add simple tensor round-trip benchmark * Use deepseq + cleaner imports * Use safe version of fromIntegral in FFI code * Don't copy data when fetching tensors BEFORE benchmarking feedFetch/4 byte time 55.79 μs (54.88 μs .. 56.62 μs) 0.998 R² (0.997 R² .. 0.999 R²) mean 55.61 μs (55.09 μs .. 56.11 μs) std dev 1.828 μs (1.424 μs .. 2.518 μs) variance introduced by outliers: 34% (moderately inflated) benchmarking feedFetch/4 KiB time 231.4 μs (221.9 μs .. 247.3 μs) 0.988 R² (0.974 R² .. 1.000 R²) mean 226.6 μs (224.1 μs .. 236.2 μs) std dev 13.45 μs (7.115 μs .. 27.14 μs) variance introduced by outliers: 57% (severely inflated) benchmarking feedFetch/4 MiB time 485.8 ms (424.6 ms .. 526.7 ms) 0.998 R² (0.994 R² .. 1.000 R²) mean 515.7 ms (512.5 ms .. 517.9 ms) std dev 3.320 ms (0.0 s .. 3.822 ms) variance introduced by outliers: 19% (moderately inflated) AFTER benchmarking feedFetch/4 byte time 53.11 μs (52.12 μs .. 54.22 μs) 0.996 R² (0.995 R² .. 0.998 R²) mean 54.64 μs (53.59 μs .. 56.18 μs) std dev 4.249 μs (2.910 μs .. 6.076 μs) variance introduced by outliers: 75% (severely inflated) benchmarking feedFetch/4 KiB time 83.83 μs (82.72 μs .. 84.92 μs) 0.999 R² (0.998 R² .. 0.999 R²) mean 83.82 μs (83.20 μs .. 84.35 μs) std dev 1.943 μs (1.557 μs .. 2.614 μs) variance introduced by outliers: 20% (moderately inflated) benchmarking feedFetch/4 MiB time 95.54 ms (93.62 ms .. 97.82 ms) 0.999 R² (0.998 R² .. 1.000 R²) mean 96.61 ms (95.76 ms .. 97.51 ms) std dev 1.408 ms (1.005 ms .. 1.889 ms)
2016-11-17 19:41:49 +01:00
import Data.Maybe (fromMaybe)
2016-10-24 21:26:42 +02:00
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
2016-10-24 21:26:42 +02:00
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S
import qualified Data.Vector.Storable.Mutable as M
2016-10-24 21:26:42 +02:00
import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import qualified TensorFlow.Internal.Raw as Raw
-- | Reinterprets a vector of bytes as a @TF_TString@ struct and copies the
-- string it points to into a fresh 'B.ByteString'.
--
-- The vector must be exactly 'Raw.sizeOfTString' bytes long. That is only
-- checked with 'assert', which GHC drops when assertions are ignored
-- (the default under @-O@), so callers are responsible for the length.
unsafeTStringToByteString :: S.Vector Word8 -> B.ByteString
unsafeTStringToByteString bytes =
    assert (S.length bytes == Raw.sizeOfTString) $
    unsafePerformIO $ S.unsafeWith bytes copyOut
  where
    -- Reads the (data pointer, size) pair out of the struct and copies
    -- that many bytes.
    copyOut rawPtr = do
        let ts = Raw.TString (castPtr rawPtr)
        dataPtr <- Raw.stringGetDataPointer ts
        size <- Raw.stringGetSize ts
        B.packCStringLen (dataPtr, fromIntegral size)
2016-10-24 21:26:42 +02:00
-- | Exception thrown when a TensorFlow C API call fails: 'checkStatus'
-- throws it with the non-OK status code and the status message, and
-- 'getAllOpList' throws it with 'Raw.TF_UNKNOWN' when the op list is null.
data TensorFlowException = TensorFlowException Raw.Code T.Text
    deriving (Show, Eq, Typeable)

instance Exception TensorFlowException
-- | All of the data needed to represent a tensor.
data TensorData = TensorData
    { tensorDataDimensions :: [Int64]
      -- ^ Size of each dimension, outermost first (one entry per
      -- dimension, as read with 'Raw.dim').
    , tensorDataType :: !DataType
      -- ^ Element type of the tensor.
    , tensorDataBytes :: !(S.Vector Word8)
      -- ^ The tensor's data buffer, copied byte-for-byte.
    }
    deriving (Show, Eq)
-- | Runs the given action after creating a session with options
-- populated by the given optionSetter.
--
-- The session is closed and deleted when the action finishes, whether it
-- returns normally or throws. Runners launched through the collector
-- passed to the action are canceled and awaited before the session is
-- deleted.
withSession :: (MonadIO m, MonadMask m)
            => (Raw.SessionOptions -> IO ())
            -> ((IO () -> IO ()) -> Raw.Session -> m a)
            -- ^ The action can spawn concurrent tasks which will
            -- be canceled before withSession returns.
            -> m a
withSession optionSetter action = do
    pending <- liftIO (newMVar [])
    let openSession options = do
            optionSetter options
            checkStatus (Raw.newSession options)
        -- Closing first nudges any pending run calls to fail and exit;
        -- deletion happens afterwards even if closing fails.
        closeSession session =
            finally (checkStatus (Raw.closeSession session))
                    (collectAndDelete session)
        -- Shuts down every collected runner, then deletes the session.
        collectAndDelete session = do
            runners <- takeMVar pending
            mapM_ shutDownRunner runners
            checkStatus (Raw.deleteSession session)
    bracket (liftIO Raw.newSessionOptions)
            (liftIO . Raw.deleteSessionOptions) $ \options ->
        bracket (liftIO (openSession options))
                (liftIO . closeSession)
                (action (asyncCollector pending))
-- | Starts the given runner asynchronously and records its handle in the
-- drain, with async exceptions masked while the drain is updated.
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner =
    modifyMVarMasked_ drain $ \running -> do
        spawned <- async runner
        return (spawned : running)
-- | Cancels a runner and waits for it to finish; any exception it ended
-- with is printed rather than rethrown.
shutDownRunner :: Async () -> IO ()
shutDownRunner r = do
    cancel r
    outcome <- waitCatch r
    -- TODO(gnezdo): manage exceptions better than print.
    case outcome of
        Left err -> print err
        Right _ -> return ()
-- | Serializes the graph definition and adds it to the session's graph,
-- throwing 'TensorFlowException' if TensorFlow reports an error.
extendGraph :: Raw.Session -> GraphDef -> IO ()
extendGraph session graphDef =
    useProtoAsVoidPtrLen graphDef $ \buf size ->
        checkStatus (Raw.extendGraph session buf size)
-- | Runs the session: feeds the given named tensors, computes the named
-- fetches and runs the named target operations.
--
-- The feed tensors are created here and always deleted once 'Raw.run'
-- returns, even when it reports an error (in which case a
-- 'TensorFlowException' is thrown). The fetched tensors are handed to
-- 'createTensorData', which takes ownership of them.
run :: Raw.Session
    -> [(B.ByteString, TensorData)] -- ^ Feeds.
    -> [B.ByteString] -- ^ Fetches.
    -> [B.ByteString] -- ^ Targets.
    -> IO [TensorData]
run session feeds fetches targets = do
    let nullTensor = Raw.Tensor nullPtr
    -- Use mask to avoid leaking input tensors before they are passed to 'run'
    -- and output tensors before they are passed to 'createTensorData'.
    mask_ $
        -- Feeds
        withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
        mapM (createRawTensor . snd) feeds >>= \feedTensors ->
        withArrayLen feedTensors $ \_ cFeedTensors ->
        -- Fetches.
        withStringArrayLen fetches $ \fetchesLen fetchNames ->
        -- tensorOuts is an array of null Tensor pointers that will be filled
        -- by the call to Raw.run.
        withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
        -- Targets.
        withStringArrayLen targets $ \targetsLen ctargets -> do
            -- checkStatus throws on a non-OK status; 'finally' makes sure the
            -- feed tensors are deleted on that path too instead of leaking.
            finally
                (checkStatus $ Raw.run
                    session
                    nullPtr
                    feedNames cFeedTensors (safeConvert feedsLen)
                    fetchNames tensorOuts (safeConvert fetchesLen)
                    ctargets (safeConvert targetsLen)
                    nullPtr)
                (mapM_ Raw.deleteTensor feedTensors)
            outTensors <- peekArray fetchesLen tensorOuts
            mapM createTensorData outTensors
-- Internal.
Optimize fetching (#27) * Add MNIST data to gitignore * Add simple tensor round-trip benchmark * Use deepseq + cleaner imports * Use safe version of fromIntegral in FFI code * Don't copy data when fetching tensors BEFORE benchmarking feedFetch/4 byte time 55.79 μs (54.88 μs .. 56.62 μs) 0.998 R² (0.997 R² .. 0.999 R²) mean 55.61 μs (55.09 μs .. 56.11 μs) std dev 1.828 μs (1.424 μs .. 2.518 μs) variance introduced by outliers: 34% (moderately inflated) benchmarking feedFetch/4 KiB time 231.4 μs (221.9 μs .. 247.3 μs) 0.988 R² (0.974 R² .. 1.000 R²) mean 226.6 μs (224.1 μs .. 236.2 μs) std dev 13.45 μs (7.115 μs .. 27.14 μs) variance introduced by outliers: 57% (severely inflated) benchmarking feedFetch/4 MiB time 485.8 ms (424.6 ms .. 526.7 ms) 0.998 R² (0.994 R² .. 1.000 R²) mean 515.7 ms (512.5 ms .. 517.9 ms) std dev 3.320 ms (0.0 s .. 3.822 ms) variance introduced by outliers: 19% (moderately inflated) AFTER benchmarking feedFetch/4 byte time 53.11 μs (52.12 μs .. 54.22 μs) 0.996 R² (0.995 R² .. 0.998 R²) mean 54.64 μs (53.59 μs .. 56.18 μs) std dev 4.249 μs (2.910 μs .. 6.076 μs) variance introduced by outliers: 75% (severely inflated) benchmarking feedFetch/4 KiB time 83.83 μs (82.72 μs .. 84.92 μs) 0.999 R² (0.998 R² .. 0.999 R²) mean 83.82 μs (83.20 μs .. 84.35 μs) std dev 1.943 μs (1.557 μs .. 2.614 μs) variance introduced by outliers: 20% (moderately inflated) benchmarking feedFetch/4 MiB time 95.54 ms (93.62 ms .. 97.82 ms) 0.999 R² (0.998 R² .. 1.000 R²) mean 96.61 ms (95.76 ms .. 97.51 ms) std dev 1.408 ms (1.005 ms .. 1.889 ms)
2016-11-17 19:41:49 +01:00
-- | Like 'fromIntegral', but fails with 'error' when the value does not fit
-- in the target type (i.e. when the conversion would be lossy). The error
-- message shows the input and the value 'fromIntegral' would have produced.
safeConvert ::
    forall a b. (Show a, Show b, Bits a, Bits b, Integral a, Integral b)
    => a -> b
safeConvert x = case toIntegralSized x of
    Just converted -> converted
    Nothing ->
        let truncated = fromIntegral x :: b
        in error ("Failed to convert " ++ show x ++ ", got " ++
                  show truncated)
2016-10-24 21:26:42 +02:00
-- | Use a list of ByteString as a list of CString.
--
-- Each string is copied into a NUL-terminated buffer that stays valid only
-- while the continuation runs; the CStrings are passed in input order.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList [] fn = fn []
-- TODO(fmayle): Is it worth using unsafeAsCString here?
withStringList (s:rest) fn =
    B.useAsCString s $ \c -> withStringList rest (fn . (c :))
-- | Use a list of ByteString as a C array of CString, passing the element
-- count and the array pointer to the continuation.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen xs fn =
    withStringList xs $ \cstrs -> withArrayLen cstrs fn
-- | Create a Raw.Tensor from a TensorData.
--
-- The tensor bytes are copied into a freshly malloc'd buffer, which is
-- passed to 'Raw.newTensor' together with 'tensorDeallocFunPtr' (which
-- frees it) as the deallocator.
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) =
    withArrayLen (map safeConvert dims) $ \rank dimPtr -> do
        let numBytes = S.length byteVec
        buffer <- mallocArray numBytes
        S.unsafeWith byteVec $ \src -> copyArray buffer src numBytes
        Raw.newTensor (toEnum (fromEnum dt))
                      dimPtr (safeConvert rank)
                      (castPtr buffer) (safeConvert numBytes)
                      tensorDeallocFunPtr nullPtr
-- C-callable deallocator passed to 'Raw.newTensor' by 'createRawTensor':
-- it frees the data pointer and ignores its other two arguments.
-- Built once with 'unsafePerformIO'; NOINLINE keeps GHC from duplicating
-- the wrapper creation at use sites.
{-# NOINLINE tensorDeallocFunPtr #-}
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
tensorDeallocFunPtr =
    unsafePerformIO (Raw.wrapTensorDealloc (\dataPtr _ _ -> free dataPtr))
-- | Create a TensorData from a Raw.Tensor.
--
-- Takes ownership of the Raw.Tensor: it is deleted before this returns,
-- including when reading or copying its contents throws.
-- TODO: Currently, it just makes a copy of the Tensor (and then deletes it),
-- since the raw pointer may refer to storage inside a mutable TensorFlow
-- variable. We should avoid that copy when it's not needed; for example,
-- by making TensorData wrap an IOVector, and changing the code that uses it.
createTensorData :: Raw.Tensor -> IO TensorData
createTensorData t = readContents `finally` Raw.deleteTensor t
  where
    -- Reads shape, type and a copy of the data out of the raw tensor.
    readContents = do
        -- Read dimensions.
        numDims <- Raw.numDims t
        dims <- mapM (Raw.dim t) [0..numDims-1]
        -- Read type.
        dtype <- toEnum . fromEnum <$> Raw.tensorType t
        -- Read data.
        len <- safeConvert <$> Raw.tensorByteSize t
        bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
        fp <- newForeignPtr_ bytes
        -- Make an explicit copy of the raw data, since it might point
        -- to a mutable variable's memory. 'S.freeze' copies eagerly, so
        -- the copy is complete before the tensor is deleted.
        v <- S.freeze (M.unsafeFromForeignPtr0 fp len)
        return $ TensorData (map safeConvert dims) dtype v
2016-10-24 21:26:42 +02:00
-- | Runs an FFI action against a freshly allocated status object, which is
-- deleted afterwards. If the action leaves a code other than 'Raw.TF_OK',
-- the status message (decoded leniently as UTF-8) is thrown as a
-- 'TensorFlowException'; otherwise the action's result is returned.
checkStatus :: (Raw.Status -> IO a) -> IO a
checkStatus fn = bracket Raw.newStatus Raw.deleteStatus $ \status -> do
    result <- fn status
    code <- Raw.getCode status
    if code == Raw.TF_OK
        then return result
        else do
            rawMsg <- Raw.message status >>= B.packCString
            throwM (TensorFlowException code
                        (T.decodeUtf8With T.lenientDecode rawMsg))
-- | Serializes the config proto and stores it in the session options,
-- throwing 'TensorFlowException' if TensorFlow rejects it.
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
setSessionConfig pb opt = useProtoAsVoidPtrLen pb applyConfig
  where
    applyConfig buf size = checkStatus (Raw.setConfig opt buf size)
-- | Sets the session target in the options, passing it as a
-- NUL-terminated C string for the duration of the call.
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
setSessionTarget target opts = B.useAsCString target (Raw.setTarget opts)
-- | Serializes the given msg and provides it as (ptr,len) argument
-- to the given action.
--
-- The pointer is only valid while the action runs. The length is converted
-- with 'safeConvert', so it fails loudly if it does not fit the requested
-- integral type.
useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
                        msg -> (Ptr b -> c -> IO a) -> IO a
useProtoAsVoidPtrLen msg f = B.useAsCStringLen serialized invoke
  where
    serialized = encodeMessage msg
    invoke (buf, size) = f (castPtr buf) (safeConvert size)
2016-10-24 21:26:42 +02:00
-- | Returns the serialized OpList of all OpDefs defined in this
-- address space.
--
-- Throws 'TensorFlowException' ('Raw.TF_UNKNOWN') if TensorFlow returns a
-- null buffer. The buffer is released by a 'Raw.deleteBuffer' finalizer.
getAllOpList :: IO B.ByteString
getAllOpList = do
    buffer <- mask_ (fetchBuffer >>= newForeignPtr Raw.deleteBuffer)
    -- Makes a copy because it is more reliable than eviscerating
    -- Buffer to steal its memory (including custom deallocator).
    withForeignPtr buffer $ \ptr -> do
        dataPtr <- castPtr <$> Raw.getBufferData ptr
        size <- safeConvert <$> Raw.getBufferLength ptr
        B.packCStringLen (dataPtr, size)
  where
    fetchBuffer = do
        p <- Raw.getAllOpList
        when (p == nullPtr) $
            throwM (TensorFlowException
                        Raw.TF_UNKNOWN "GetAllOpList failure, check logs")
        return p