-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}

module TensorFlow.Internal.FFI
    ( TensorFlowException(..)
    , Raw.Session
    , withSession
    , extendGraph
    , run
    , TensorData(..)
    , setSessionConfig
    , setSessionTarget
    , getAllOpList
      -- * Internal helper.
    , useProtoAsVoidPtrLen
    )
    where

import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception
    (Exception, throwIO, bracket, finally, mask_, onException)
import Control.Monad (when)
import Data.Int (Int64)
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S

import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)

import qualified TensorFlow.Internal.Raw as Raw

-- | A non-OK status reported by the TensorFlow C library: the status code
-- together with its (leniently UTF-8 decoded) message.
data TensorFlowException = TensorFlowException Raw.Code T.Text
    deriving (Show, Eq, Typeable)

instance Exception TensorFlowException

-- | All of the data needed to represent a tensor.
data TensorData = TensorData
    { tensorDataDimensions :: [Int64]
    , tensorDataType       :: !DataType
    , tensorDataBytes      :: !(S.Vector Word8)
    }
  deriving (Show, Eq)

-- | Runs the given action after creating a session with options
-- populated by the given optionSetter.
--
-- The action receives a launcher (@IO () -> IO ()@) that starts its
-- argument as an 'Async' and records it. When the action finishes (or
-- throws), the session is closed first, then every recorded runner is
-- cancelled and waited for, and finally the session is deleted.
withSession :: (Raw.SessionOptions -> IO ())
            -> ((IO () -> IO ()) -> Raw.Session -> IO a)
            -- ^ The action can spawn concurrent tasks which will
            -- be canceled before withSession returns.
            -> IO a
withSession optionSetter action = do
    drain <- newMVar []
    let cleanup s =
            -- Closes the session to nudge the pending run calls to fail
            -- and exit.
            finally (checkStatus (Raw.closeSession s)) $ do
                -- 'takeMVar' leaves the MVar empty, so any later attempt
                -- to launch a runner through it would block.
                runners <- takeMVar drain
                -- Collects all runners before deleting the session.
                mapM_ shutDownRunner runners
                checkStatus (Raw.deleteSession s)
    bracket Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
        optionSetter options
        bracket
            (checkStatus (Raw.newSession options))
            cleanup
            (action (asyncCollector drain))

-- | Starts the runner as an 'Async' and prepends it to the list in the
-- given MVar. The update runs masked, so a started runner is always
-- recorded.
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
  where
    launchAndRecord restRunners = (: restRunners) <$> async runner

-- | Cancels the runner and waits for it to finish, printing any exception
-- it ended with.
shutDownRunner :: Async () -> IO ()
shutDownRunner r = do
    cancel r
    -- TODO(gnezdo): manage exceptions better than print.
    -- NOTE(review): for a runner that was still running, 'waitCatch'
    -- presumably returns the cancellation exception itself, which is then
    -- printed too — confirm whether that output is wanted.
    either print (const (return ())) =<< waitCatch r

-- | Serializes the GraphDef and passes it to 'Raw.extendGraph', throwing
-- 'TensorFlowException' on a non-OK status.
extendGraph :: Raw.Session -> GraphDef -> IO ()
extendGraph session pb =
    useProtoAsVoidPtrLen pb $ \ptr len ->
        checkStatus $ Raw.extendGraph session ptr len

-- | Runs the session with the given feeds, fetching the named outputs and
-- running the named targets. Returns one 'TensorData' per fetch, in order.
-- Throws 'TensorFlowException' if the run reports a non-OK status.
run :: Raw.Session
    -> [(B.ByteString, TensorData)] -- ^ Feeds.
    -> [B.ByteString]               -- ^ Fetches.
    -> [B.ByteString]               -- ^ Targets.
    -> IO [TensorData]
run session feeds fetches targets = do
    let nullTensor = Raw.Tensor nullPtr
    -- Use mask to avoid leaking input tensors before they are passed to 'run'
    -- and output tensors before they are passed to 'createTensorData'.
    mask_ $ do
        outTensors <-
            -- Feeds. 'withRawTensors' deletes the feed tensors when this
            -- scope exits, whether 'Raw.run' succeeded or threw.
            withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
            withRawTensors (snd <$> feeds) $ \feedTensors ->
            withArrayLen feedTensors $ \_ cFeedTensors ->
            -- Fetches.
            withStringArrayLen fetches $ \fetchesLen fetchNames ->
            -- tensorOuts is an array of null Tensor pointers that will be
            -- filled by the call to Raw.run.
            withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
            -- Targets.
            withStringArrayLen targets $ \targetsLen ctargets -> do
                checkStatus $ Raw.run
                    session
                    nullPtr
                    feedNames cFeedTensors (fromIntegral feedsLen)
                    fetchNames tensorOuts (fromIntegral fetchesLen)
                    ctargets (fromIntegral targetsLen)
                    nullPtr
                peekArray fetchesLen tensorOuts
        createTensorDatas outTensors

-- Internal.

-- | Creates a Raw.Tensor for each TensorData and passes them, in order, to
-- the given action. Every tensor created here is deleted when the action
-- returns or throws. This includes the ones already created if a later
-- 'createRawTensor' call fails.
withRawTensors :: [TensorData] -> ([Raw.Tensor] -> IO a) -> IO a
withRawTensors datas fn = go datas []
  where
    go [] acc = fn (reverse acc)
    go (d:ds) acc =
        bracket (createRawTensor d) Raw.deleteTensor $ \t -> go ds (t : acc)

-- | Converts each output tensor with 'createTensorData', which takes
-- ownership of it. If one conversion throws, the tensors not yet
-- converted are deleted before the exception propagates.
createTensorDatas :: [Raw.Tensor] -> IO [TensorData]
createTensorDatas [] = return []
createTensorDatas (t:ts) = do
    d <- createTensorData t `onException` mapM_ Raw.deleteTensor ts
    (d :) <$> createTensorDatas ts

-- | Use a list of ByteString as a list of CString.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = go strings []
  where
    go [] cs = fn (reverse cs)
    -- TODO(fmayle): Is it worth using unsafeAsCString here?
    go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs)

-- | Use a list of ByteString as an array of CString.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)

-- | Create a Raw.Tensor from a TensorData.
--
-- The tensor bytes are copied into a buffer allocated with 'mallocArray'.
-- That buffer is handed to 'Raw.newTensor' together with
-- 'tensorDeallocFunPtr', which frees it.
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) =
    -- Convert the data type before allocating anything. An unsupported
    -- type then fails here instead of after 'mallocArray', where it would
    -- leak the buffer.
    rawType `seq` withArrayLen (map fromIntegral dims) $ \cdimsLen cdims -> do
        let len = S.length byteVec
        dest <- mallocArray len
        S.unsafeWith byteVec $ \x -> copyArray dest x len
        Raw.newTensor rawType
                      cdims (fromIntegral cdimsLen)
                      (castPtr dest) (fromIntegral len)
                      tensorDeallocFunPtr nullPtr
  where
    rawType = toEnum (fromEnum dt)

{-# NOINLINE tensorDeallocFunPtr #-}
-- | Deallocator passed to 'Raw.newTensor' by 'createRawTensor'. It 'free's
-- the data pointer and ignores its other two arguments.
--
-- 'unsafePerformIO' with NOINLINE is the usual idiom for building this
-- FunPtr wrapper once, as a top-level constant shared by every tensor. It
-- is never freed here.
tensorDeallocFunPtr :: FunPtr Raw.TensorDeallocFn
tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x

-- | Create a TensorData from a Raw.Tensor.
--
-- Takes ownership of the Raw.Tensor. It is deleted before this returns,
-- even if reading it throws.
createTensorData :: Raw.Tensor -> IO TensorData
createTensorData t = readTensor `finally` Raw.deleteTensor t
  where
    readTensor = do
        -- Read dimensions.
        numDims <- Raw.numDims t
        dims <- mapM (Raw.dim t) [0..numDims-1]
        -- Read type.
        dtype <- toEnum . fromEnum <$> Raw.tensorType t
        -- Read data.
        len <- fromIntegral <$> Raw.tensorByteSize t
        bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
        -- TODO(fmayle): Don't copy the data.
        v <- S.fromList <$> peekArray len bytes
        -- Force the constructor and its strict fields here. A failed data
        -- type conversion is then raised by this call, not lazily at some
        -- later use site.
        return $! TensorData (map fromIntegral dims) dtype v

-- | Runs the given action which does FFI calls updating a provided
-- status object. If the status is not OK it is thrown as
-- TensorFlowException. The status object is deleted in all cases.
checkStatus :: (Raw.Status -> IO a) -> IO a
checkStatus fn =
    bracket Raw.newStatus Raw.deleteStatus $ \status -> do
        result <- fn status
        code <- Raw.getCode status
        when (code /= Raw.TF_OK) $ do
            msg <- T.decodeUtf8With T.lenientDecode <$>
                   (Raw.message status >>= B.packCString)
            throwIO $ TensorFlowException code msg
        return result

-- | Serializes the ConfigProto into the session options, throwing
-- 'TensorFlowException' on a non-OK status.
setSessionConfig :: ConfigProto -> Raw.SessionOptions -> IO ()
setSessionConfig pb opt =
    useProtoAsVoidPtrLen pb $ \ptr len ->
        checkStatus (Raw.setConfig opt ptr len)

-- | Sets the session target. The ByteString is passed as a NUL-terminated
-- CString that is only valid for the duration of the call.
setSessionTarget :: B.ByteString -> Raw.SessionOptions -> IO ()
setSessionTarget target = B.useAsCString target . Raw.setTarget

-- | Serializes the given msg and provides it as (ptr,len) argument
-- to the given action. The pointer is only valid inside the action.
useProtoAsVoidPtrLen :: (Message msg, Num c) =>
    msg -> (Ptr b -> c -> IO a) -> IO a
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
        \(bytes, len) -> f (castPtr bytes) (fromIntegral len)

-- | Returns the serialized OpList of all OpDefs defined in this
-- address space.
getAllOpList :: IO B.ByteString
getAllOpList = do
    foreignPtr <-
        mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall)
    -- Makes a copy because it is more reliable than eviscerating
    -- Buffer to steal its memory (including custom deallocator).
    withForeignPtr foreignPtr $
        \ptr -> B.packCStringLen =<< (,)
                <$> (castPtr <$> Raw.getBufferData ptr)
                <*> (fromIntegral <$> Raw.getBufferLength ptr)
  where
    -- A null result is reported as a TF_UNKNOWN TensorFlowException.
    checkCall = do
        p <- Raw.getAllOpList
        when (p == nullPtr) (throwIO exception)
        return p
    exception = TensorFlowException
        Raw.TF_UNKNOWN "GetAllOpList failure, check logs"