tensorflow-haskell/tensorflow/src/TensorFlow/Internal/FFI.hs

244 lines
9.0 KiB
Haskell

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
module TensorFlow.Internal.FFI
( TensorFlowException(..)
, Raw.Session
, withSession
, extendGraph
, run
, TensorData(..)
, setSessionConfig
, setSessionTarget
, getAllOpList
-- * Internal helper.
, useProtoAsVoidPtrLen
)
where
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
import Control.Monad (when)
import Data.Int (Int64)
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (mallocForeignPtrArray, newForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S
import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
import qualified TensorFlow.Internal.Raw as Raw
-- | Exception carrying a status code and its accompanying message.
--
-- Thrown by 'checkStatus' when a status is not 'Raw.TF_OK', and by
-- 'getAllOpList' when the library returns a null buffer.
data TensorFlowException = TensorFlowException Raw.Code T.Text
    deriving (Show, Eq, Typeable)

instance Exception TensorFlowException
-- | Everything required to describe a single tensor value.
data TensorData = TensorData
    { tensorDataDimensions :: [Int64]
      -- ^ Size of each dimension of the tensor.
    , tensorDataType :: !DataType
      -- ^ Element type of the tensor.
    , tensorDataBytes :: !(S.Vector Word8)
      -- ^ Raw contents of the tensor buffer.
    }
    deriving (Show, Eq)
-- | Creates a session whose options have been populated by
-- @optionSetter@, runs the action with it, and tears the session down
-- afterwards (also when the action throws).
withSession :: (Raw.SessionOptions -> IO ())
            -> ((IO () -> IO ()) -> Raw.Session -> IO a)
            -- ^ The action can spawn concurrent tasks which will
            -- be canceled before withSession returns.
            -> IO a
withSession optionSetter action = do
    pending <- newMVar []
    bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \opts -> do
        optionSetter opts
        bracket
            (openSession opts)
            (tearDown pending)
            (action (asyncCollector pending))
  where
    openSession opts = checkStatus (Raw.newSession opts)

    -- Closing the session first nudges the pending run calls to fail and
    -- exit; the runners are then collected even if closing itself failed.
    tearDown pending session =
        checkStatus (Raw.closeSession session)
            `finally` collectAndDelete pending session

    -- All runners are shut down before the session is deleted.
    collectAndDelete pending session = do
        runners <- takeMVar pending
        mapM_ shutDownRunner runners
        checkStatus (Raw.deleteSession session)
-- | Starts @runner@ asynchronously and records its handle at the front of
-- the list stored in @drain@. The update runs with asynchronous exceptions
-- masked ('modifyMVarMasked_').
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner =
    modifyMVarMasked_ drain $ \running -> do
        handle <- async runner
        return (handle : running)
-- | Cancels the given task, waits for it to finish and prints the
-- exception it ended with, if any.
shutDownRunner :: Async () -> IO ()
shutDownRunner r = do
    cancel r
    outcome <- waitCatch r
    -- TODO(gnezdo): manage exceptions better than print.
    case outcome of
        Left err -> print err
        Right _ -> return ()
-- | Serializes the given 'GraphDef' and passes it to 'Raw.extendGraph'
-- for the session, throwing 'TensorFlowException' on a non-OK status.
extendGraph :: Raw.Session -> GraphDef -> IO ()
extendGraph session pb = useProtoAsVoidPtrLen pb extend
  where
    extend ptr len = checkStatus (Raw.extendGraph session ptr len)
-- | Calls 'Raw.run' on the session with the given feeds, fetches and
-- targets, and returns one 'TensorData' per fetch, in fetch order.
-- Throws 'TensorFlowException' if the call reports a non-OK status.
run :: Raw.Session
    -> [(B.ByteString, TensorData)]  -- ^ Feeds.
    -> [B.ByteString]                -- ^ Fetches.
    -> [B.ByteString]                -- ^ Targets.
    -> IO [TensorData]
run session feeds fetches targets =
    -- Masked so that input tensors are not leaked before they reach
    -- 'Raw.run', nor output tensors before they reach 'createTensorData'.
    mask_ $
    -- Feeds.
    withStringArrayLen feedKeys $ \numFeeds cFeedNames -> do
        rawFeeds <- mapM createRawTensor feedValues
        withArrayLen rawFeeds $ \_ cFeeds ->
            -- Fetches.
            withStringArrayLen fetches $ \numFetches cFetchNames ->
            -- The output array starts out as null tensor pointers and is
            -- filled in by the call to 'Raw.run'.
            withArrayLen (replicate numFetches emptySlot) $ \_ cOutputs ->
            -- Targets.
            withStringArrayLen targets $ \numTargets cTargets -> do
                checkStatus $ Raw.run
                    session
                    nullPtr
                    cFeedNames cFeeds (fromIntegral numFeeds)
                    cFetchNames cOutputs (fromIntegral numFetches)
                    cTargets (fromIntegral numTargets)
                    nullPtr
                peekArray numFetches cOutputs >>= mapM createTensorData
  where
    (feedKeys, feedValues) = unzip feeds
    emptySlot = Raw.Tensor nullPtr
-- Internal.

-- | Exposes each 'B.ByteString' as a 'CString' (via 'B.useAsCString') and
-- runs the continuation on the resulting list, which is in input order.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = foldr expose (fn . reverse) strings []
  where
    -- Pointers are accumulated in reverse while the nested scopes are
    -- opened; the base continuation restores the input order.
    -- TODO(fmayle): Is it worth using unsafeAsCString here?
    expose bytes next acc = B.useAsCString bytes $ \c -> next (c : acc)
-- | Exposes a list of 'B.ByteString' as a C array of 'CString', passing
-- the element count and the array pointer to the continuation.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen xs fn =
    withStringList xs $ \cstrings -> withArrayLen cstrings fn
-- | Builds a 'Raw.Tensor' from a 'TensorData'.
--
-- The bytes are copied into a buffer obtained from 'mallocArray', and
-- 'tensorDeallocFunPtr' is passed along as the deallocator for it.
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) =
    withArrayLen (fromIntegral <$> dims) $ \rank dimsPtr -> do
        buffer <- mallocArray byteCount
        S.unsafeWith byteVec $ \src -> copyArray buffer src byteCount
        Raw.newTensor
            (toEnum (fromEnum dt))
            dimsPtr (fromIntegral rank)
            (castPtr buffer) (fromIntegral byteCount)
            tensorDeallocFunPtr nullPtr
  where
    byteCount = S.length byteVec
-- | Function pointer to a deallocator that releases its first argument
-- with 'free' and ignores the other two.
--
-- The wrapper is created with 'unsafePerformIO' as a top-level value;
-- NOINLINE keeps GHC from inlining (and thereby repeating) that call.
{-# NOINLINE tensorDeallocFunPtr #-}
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
tensorDeallocFunPtr = unsafePerformIO (Raw.wrapTensorDealloc release)
  where
    release buffer _ _ = free buffer
-- | Create a TensorData from a Raw.Tensor.
--
-- Takes ownership of the Raw.Tensor: it is deleted before this function
-- returns, including when reading it throws.
createTensorData :: Raw.Tensor -> IO TensorData
createTensorData t = readTensor `finally` Raw.deleteTensor t
  where
    readTensor = do
        -- Read dimensions.
        numDims <- Raw.numDims t
        dims <- mapM (Raw.dim t) [0 .. numDims - 1]
        -- Read type.
        dtype <- toEnum . fromEnum <$> Raw.tensorType t
        -- Read data.
        len <- fromIntegral <$> Raw.tensorByteSize t
        bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
        -- Copy the bytes in one bulk operation into a buffer owned by the
        -- Haskell heap instead of going through an intermediate [Word8].
        -- TODO(fmayle): Don't copy the data.
        buffer <- mallocForeignPtrArray len
        -- Skip the copy for empty tensors so the source pointer is never
        -- touched in that case (it may be null -- TODO confirm).
        when (len > 0) $
            withForeignPtr buffer $ \dest -> copyArray dest bytes len
        -- 'buffer' is not written to after this point, so exposing it as
        -- an immutable vector is safe.
        return $ TensorData
            (map fromIntegral dims)
            dtype
            (S.unsafeFromForeignPtr0 buffer len)
-- | Runs an FFI action against a freshly allocated status object and
-- returns its result. If the status is anything other than
-- 'Raw.TF_OK' afterwards, a 'TensorFlowException' carrying the code and
-- the (leniently UTF-8 decoded) message is thrown instead. The status
-- object is deleted in either case.
checkStatus :: (Raw.Status -> IO a) -> IO a
checkStatus fn = bracket Raw.newStatus Raw.deleteStatus runAndCheck
  where
    runAndCheck status = do
        result <- fn status
        code <- Raw.getCode status
        if code /= Raw.TF_OK
            then do
                rawMsg <- Raw.message status >>= B.packCString
                throwIO $ TensorFlowException code
                    (T.decodeUtf8With T.lenientDecode rawMsg)
            else return result
-- | Serializes the given 'ConfigProto' and applies it to the session
-- options through 'Raw.setConfig', throwing 'TensorFlowException' on a
-- non-OK status.
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
setSessionConfig pb opt = useProtoAsVoidPtrLen pb apply
  where
    apply ptr len = checkStatus (Raw.setConfig opt ptr len)
-- | Passes @target@ (as a NUL-terminated C string) to 'Raw.setTarget' for
-- the given session options.
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
setSessionTarget target opt = B.useAsCString target (Raw.setTarget opt)
-- | Encodes @msg@ to bytes and hands the action a pointer to those bytes
-- together with their length. The pointer comes from
-- 'B.useAsCStringLen', so it must not be used after the action returns.
useProtoAsVoidPtrLen :: (Message msg, Num c) =>
                        msg -> (Ptr b -> c -> IO a) -> IO a
useProtoAsVoidPtrLen msg f =
    B.useAsCStringLen (encodeMessage msg) (uncurry invoke)
  where
    invoke bytes len = f (castPtr bytes) (fromIntegral len)
-- | Returns the serialized OpList of all OpDefs defined in this
-- address space.
--
-- Throws 'TensorFlowException' with 'Raw.TF_UNKNOWN' when the underlying
-- call returns a null buffer.
getAllOpList :: IO B.ByteString
getAllOpList = do
    -- Masked so the buffer cannot be lost between obtaining it and
    -- attaching its finalizer.
    buffer <- mask_ (acquire >>= newForeignPtr Raw.deleteBuffer)
    -- Makes a copy because it is more reliable than eviscerating
    -- Buffer to steal its memory (including custom deallocator).
    withForeignPtr buffer $ \ptr -> do
        contents <- castPtr <$> Raw.getBufferData ptr
        size <- fromIntegral <$> Raw.getBufferLength ptr
        B.packCStringLen (contents, size)
  where
    acquire = do
        p <- Raw.getAllOpList
        if p == nullPtr
            then throwIO failure
            else return p
    failure = TensorFlowException
        Raw.TF_UNKNOWN "GetAllOpList failure, check logs"