280 lines
11 KiB
Haskell
280 lines
11 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE DeriveDataTypeable #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
|
|
module TensorFlow.Internal.FFI
|
|
( TensorFlowException(..)
|
|
, Raw.Session
|
|
, withSession
|
|
, extendGraph
|
|
, run
|
|
, TensorData(..)
|
|
, setSessionConfig
|
|
, setSessionTarget
|
|
, getAllOpList
|
|
, unsafeTStringToByteString
|
|
-- * Internal helper.
|
|
, useProtoAsVoidPtrLen
|
|
)
|
|
where
|
|
|
|
import Control.Exception (assert)
|
|
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
|
|
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
|
import Control.Monad (when)
|
|
import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask_)
|
|
import Control.Monad.IO.Class (MonadIO, liftIO)
|
|
import Data.Bits (Bits, toIntegralSized)
|
|
import Data.Int (Int64)
|
|
import Data.Maybe (fromMaybe)
|
|
import Data.Typeable (Typeable)
|
|
import Data.Word (Word8)
|
|
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
|
|
import Foreign.C.String (CString)
|
|
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
|
|
import Foreign.Marshal.Alloc (free)
|
|
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
|
|
import System.IO.Unsafe (unsafePerformIO)
|
|
import qualified Data.ByteString as B
|
|
import qualified Data.Text as T
|
|
import qualified Data.Text.Encoding as T
|
|
import qualified Data.Text.Encoding.Error as T
|
|
import qualified Data.Vector.Storable as S
|
|
import qualified Data.Vector.Storable.Mutable as M
|
|
|
|
import Data.ProtoLens (Message, encodeMessage)
|
|
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
|
|
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
|
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
|
|
|
|
import qualified TensorFlow.Internal.Raw as Raw
|
|
|
|
-- Interpret a vector of bytes as a TF_TString struct and copy the pointed
|
|
-- to string into a ByteString.
|
|
unsafeTStringToByteString :: S.Vector Word8 -> B.ByteString
|
|
unsafeTStringToByteString v =
|
|
assert (S.length v == Raw.sizeOfTString) $
|
|
unsafePerformIO $ S.unsafeWith v $ \tstringPtr -> do
|
|
let tstring = Raw.TString (castPtr tstringPtr)
|
|
p <- Raw.stringGetDataPointer tstring
|
|
n <- Raw.stringGetSize tstring
|
|
B.packCStringLen (p, fromIntegral n)
|
|
|
|
data TensorFlowException = TensorFlowException Raw.Code T.Text
|
|
deriving (Show, Eq, Typeable)
|
|
|
|
instance Exception TensorFlowException
|
|
|
|
-- | All of the data needed to represent a tensor.
|
|
data TensorData = TensorData
|
|
{ tensorDataDimensions :: [Int64]
|
|
, tensorDataType :: !DataType
|
|
, tensorDataBytes :: !(S.Vector Word8)
|
|
}
|
|
deriving (Show, Eq)
|
|
|
|
-- | Runs the given action after creating a session with options
|
|
-- populated by the given optionSetter.
|
|
withSession :: (MonadIO m, MonadMask m)
|
|
=> (Raw.SessionOptions -> IO ())
|
|
-> ((IO () -> IO ()) -> Raw.Session -> m a)
|
|
-- ^ The action can spawn concurrent tasks which will
|
|
-- be canceled before withSession returns.
|
|
-> m a
|
|
withSession optionSetter action = do
|
|
drain <- liftIO $ newMVar []
|
|
let cleanup s =
|
|
-- Closes the session to nudge the pending run calls to fail and exit.
|
|
finally (checkStatus (Raw.closeSession s)) $ do
|
|
runners <- takeMVar drain
|
|
-- Collects all runners before deleting the session.
|
|
mapM_ shutDownRunner runners
|
|
checkStatus (Raw.deleteSession s)
|
|
let bracketIO x y = bracket (liftIO x) (liftIO . y)
|
|
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
|
|
bracketIO
|
|
(optionSetter options >> checkStatus (Raw.newSession options))
|
|
cleanup
|
|
(action (asyncCollector drain))
|
|
|
|
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
|
|
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
|
|
where
|
|
launchAndRecord restRunners = (: restRunners) <$> async runner
|
|
|
|
shutDownRunner :: Async () -> IO ()
|
|
shutDownRunner r = do
|
|
cancel r
|
|
-- TODO(gnezdo): manage exceptions better than print.
|
|
either print (const (return ())) =<< waitCatch r
|
|
|
|
extendGraph :: Raw.Session -> GraphDef -> IO ()
|
|
extendGraph session pb =
|
|
useProtoAsVoidPtrLen pb $ \ptr len ->
|
|
checkStatus $ Raw.extendGraph session ptr len
|
|
|
|
|
|
run :: Raw.Session
|
|
-> [(B.ByteString, TensorData)] -- ^ Feeds.
|
|
-> [B.ByteString] -- ^ Fetches.
|
|
-> [B.ByteString] -- ^ Targets.
|
|
-> IO [TensorData]
|
|
run session feeds fetches targets = do
|
|
let nullTensor = Raw.Tensor nullPtr
|
|
-- Use mask to avoid leaking input tensors before they are passed to 'run'
|
|
-- and output tensors before they are passed to 'createTensorData'.
|
|
mask_ $
|
|
-- Feeds
|
|
withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
|
|
mapM (createRawTensor . snd) feeds >>= \feedTensors ->
|
|
withArrayLen feedTensors $ \_ cFeedTensors ->
|
|
-- Fetches.
|
|
withStringArrayLen fetches $ \fetchesLen fetchNames ->
|
|
-- tensorOuts is an array of null Tensor pointers that will be filled
|
|
-- by the call to Raw.run.
|
|
withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
|
|
-- Targets.
|
|
withStringArrayLen targets $ \targetsLen ctargets -> do
|
|
checkStatus $ Raw.run
|
|
session
|
|
nullPtr
|
|
feedNames cFeedTensors (safeConvert feedsLen)
|
|
fetchNames tensorOuts (safeConvert fetchesLen)
|
|
ctargets (safeConvert targetsLen)
|
|
nullPtr
|
|
mapM_ Raw.deleteTensor feedTensors
|
|
outTensors <- peekArray fetchesLen tensorOuts
|
|
mapM createTensorData outTensors
|
|
|
|
|
|
-- Internal.
|
|
|
|
|
|
-- | Same as 'fromIntegral', but throws an error if conversion is "lossy".
|
|
safeConvert ::
|
|
forall a b. (Show a, Show b, Bits a, Bits b, Integral a, Integral b)
|
|
=> a -> b
|
|
safeConvert x =
|
|
fromMaybe
|
|
(error ("Failed to convert " ++ show x ++ ", got " ++
|
|
show (fromIntegral x :: b)))
|
|
(toIntegralSized x)
|
|
|
|
|
|
-- | Use a list of ByteString as a list of CString.
|
|
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
|
|
withStringList strings fn = go strings []
|
|
where
|
|
go [] cs = fn (reverse cs)
|
|
-- TODO(fmayle): Is it worth using unsafeAsCString here?
|
|
go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs)
|
|
|
|
|
|
-- | Use a list of ByteString as an array of CString.
|
|
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
|
|
withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
|
|
|
|
|
|
-- | Create a Raw.Tensor from a TensorData.
|
|
createRawTensor :: TensorData -> IO Raw.Tensor
|
|
createRawTensor (TensorData dims dt byteVec) =
|
|
withArrayLen (map safeConvert dims) $ \cdimsLen cdims -> do
|
|
let len = S.length byteVec
|
|
dest <- mallocArray len
|
|
S.unsafeWith byteVec $ \x -> copyArray dest x len
|
|
Raw.newTensor (toEnum $ fromEnum dt)
|
|
cdims (safeConvert cdimsLen)
|
|
(castPtr dest) (safeConvert len)
|
|
tensorDeallocFunPtr nullPtr
|
|
|
|
{-# NOINLINE tensorDeallocFunPtr #-}
|
|
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
|
|
tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x
|
|
|
|
-- | Create a TensorData from a Raw.Tensor.
|
|
--
|
|
-- Takes ownership of the Raw.Tensor.
|
|
-- TODO: Currently, it just makes a copy of the Tensor (and then deletes it),
|
|
-- since the raw pointer may refer to storage inside a mutable TensorFlow
|
|
-- variable. We should avoid that copy when it's not needed; for example,
|
|
-- by making TensorData wrap an IOVector, and changing the code that uses it.
|
|
createTensorData :: Raw.Tensor -> IO TensorData
|
|
createTensorData t = do
|
|
-- Read dimensions.
|
|
numDims <- Raw.numDims t
|
|
dims <- mapM (Raw.dim t) [0..numDims-1]
|
|
-- Read type.
|
|
dtype <- toEnum . fromEnum <$> Raw.tensorType t
|
|
-- Read data.
|
|
len <- safeConvert <$> Raw.tensorByteSize t
|
|
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
|
|
fp <- newForeignPtr_ bytes
|
|
-- Make an explicit copy of the raw data, since it might point
|
|
-- to a mutable variable's memory.
|
|
v <- S.freeze (M.unsafeFromForeignPtr0 fp len)
|
|
Raw.deleteTensor t
|
|
return $ TensorData (map safeConvert dims) dtype v
|
|
|
|
-- | Runs the given action which does FFI calls updating a provided
|
|
-- status object. If the status is not OK it is thrown as
|
|
-- TensorFlowException.
|
|
checkStatus :: (Raw.Status -> IO a) -> IO a
|
|
checkStatus fn =
|
|
bracket Raw.newStatus Raw.deleteStatus $ \status -> do
|
|
result <- fn status
|
|
code <- Raw.getCode status
|
|
when (code /= Raw.TF_OK) $ do
|
|
msg <- T.decodeUtf8With T.lenientDecode <$>
|
|
(Raw.message status >>= B.packCString)
|
|
throwM $ TensorFlowException code msg
|
|
return result
|
|
|
|
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
|
|
setSessionConfig pb opt =
|
|
useProtoAsVoidPtrLen pb $ \ptr len ->
|
|
checkStatus (Raw.setConfig opt ptr len)
|
|
|
|
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
|
|
setSessionTarget target = B.useAsCString target . Raw.setTarget
|
|
|
|
-- | Serializes the given msg and provides it as (ptr,len) argument
|
|
-- to the given action.
|
|
useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
|
|
msg -> (Ptr b -> c -> IO a) -> IO a
|
|
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
|
|
\(bytes, len) -> f (castPtr bytes) (safeConvert len)
|
|
|
|
-- | Returns the serialized OpList of all OpDefs defined in this
|
|
-- address space.
|
|
getAllOpList :: IO B.ByteString
|
|
getAllOpList = do
|
|
foreignPtr <-
|
|
mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall)
|
|
-- Makes a copy because it is more reliable than eviscerating
|
|
-- Buffer to steal its memory (including custom deallocator).
|
|
withForeignPtr foreignPtr $
|
|
\ptr -> B.packCStringLen =<< (,)
|
|
<$> (castPtr <$> Raw.getBufferData ptr)
|
|
<*> (safeConvert <$> Raw.getBufferLength ptr)
|
|
where
|
|
checkCall = do
|
|
p <- Raw.getAllOpList
|
|
when (p == nullPtr) (throwM exception)
|
|
return p
|
|
exception = TensorFlowException
|
|
Raw.TF_UNKNOWN "GetAllOpList failure, check logs"
|