tensorflow-haskell/tensorflow/src/TensorFlow/Internal/Raw.chs

259 lines
8.4 KiB
Plaintext

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ForeignFunctionInterface #-}
module TensorFlow.Internal.Raw where
#include "third_party/tensorflow/c/c_api.h"
import Data.ByteString (ByteString, packCString, useAsCString)
import Foreign
import Foreign.C
-- | Haskell mirror of the C enum @TF_DataType@, generated by c2hs from
-- c_api.h. The empty @{}@ translation table means constructors keep the C
-- enumerator names.
{# enum TF_DataType as DataType {} deriving (Show, Eq) #}
-- | Haskell mirror of the C enum @TF_Code@, the code stored in a 'Status'
-- (read with 'getCode', written with 'setStatus').
{# enum TF_Code as Code {} deriving (Show, Eq) #}
-- Status.
-- | Opaque handle to a C @TF_Status@ (c2hs newtype around a pointer).
{# pointer *TF_Status as Status newtype #}

-- | Allocate a fresh status object (@TF_NewStatus@); pair with
-- 'deleteStatus'.
newStatus :: IO Status
newStatus = {# call TF_NewStatus as ^ #}

-- | Free a status object (@TF_DeleteStatus@).
deleteStatus :: Status -> IO ()
deleteStatus status = {# call TF_DeleteStatus as ^ #} status

-- | Store a code and message in a status (@TF_SetStatus@). The 'Code' is
-- converted to its numeric C enum value before the call.
setStatus :: Status -> Code -> CString -> IO ()
setStatus status code msg =
    {# call TF_SetStatus as ^ #} status (fromIntegral (fromEnum code)) msg

-- | Read the code held by a status (@TF_GetCode@), decoded into 'Code'.
getCode :: Status -> IO Code
getCode = fmap (toEnum . fromIntegral) . {# call TF_GetCode as ^ #}

-- | Read the message held by a status (@TF_Message@).
-- NOTE(review): the returned 'CString' presumably points into memory owned
-- by the status (do not free it) — confirm against c_api.h.
message :: Status -> IO CString
message status = {# call TF_Message as ^ #} status
-- TString.
-- | Opaque handle to a C @TF_TString@.
{# pointer *TF_TString as TString newtype #}

-- | Size in bytes of one @TF_TString@. This is hard-coded rather than
-- computed by c2hs.
-- NOTE(review): it must equal @sizeof(TF_TString)@ in the TensorFlow
-- headers being bound. Re-check it on TensorFlow upgrades.
sizeOfTString :: Int
sizeOfTString = 24

-- | Numeric value of @TF_TString_Type::TF_TSTR_OFFSET@, hard-coded to
-- mirror the C enum.
tstringOffsetTypeTag :: Word32
tstringOffsetTypeTag = 2

-- | Pointer to the character data of a @TF_TString@
-- (@TF_StringGetDataPointer@).
stringGetDataPointer :: TString -> IO CString
stringGetDataPointer tstr = {# call TF_StringGetDataPointer as ^ #} tstr

-- | Length of a @TF_TString@'s contents (@TF_StringGetSize@).
stringGetSize :: TString -> IO CULong
stringGetSize tstr = {# call TF_StringGetSize as ^ #} tstr
-- Operation.
-- | Opaque handle to a C @TF_Operation@.
{# pointer *TF_Operation as Operation newtype #}

-- | Name of an operation (@TF_OperationName@), copied into a 'ByteString'
-- with 'packCString'.
{# fun TF_OperationName as operationName { `Operation' } -> `ByteString' packCString* #}

-- | Number of outputs of an operation (@TF_OperationNumOutputs@).
{# fun TF_OperationNumOutputs as operationNumOutputs { `Operation' } -> `Int' #}

-- | An 'Operation' is stored exactly like its wrapped @Ptr Operation@, so a
-- C array of operations (e.g. the @Ptr Operation@ targets passed to 'run')
-- holds one pointer per element. 'sizeOf' and 'alignment' never inspect
-- their argument.
instance Storable Operation where
    sizeOf _ = sizeOf (nullPtr :: Ptr Operation)
    alignment _ = alignment (nullPtr :: Ptr Operation)
    peek ptr = Operation <$> peek (castPtr ptr)
    poke ptr (Operation raw) = poke (castPtr ptr) raw
-- Output.
-- | Haskell counterpart of the C struct @TF_Output@: an operation together
-- with the index of one of its outputs.
data Output = Output
    { outputOperation :: Operation
    , outputIndex :: CInt
    }

-- | Alias for @Ptr Output@, used wherever the C API takes a @TF_Output*@.
{# pointer *TF_Output as OutputPtr -> Output #}

-- | Marshals 'Output' to and from a C @TF_Output@. c2hs generates the size,
-- alignment and field accessors from the C struct definition. Fields are
-- read and written in declaration order (@oper@, then @index@).
instance Storable Output where
    sizeOf _ = {# sizeof TF_Output #}
    alignment _ = {# alignof TF_Output #}
    peek ptr = do
        op <- {# get TF_Output->oper #} ptr
        idx <- {# get TF_Output->index #} ptr
        return (Output op (fromIntegral idx))
    poke ptr (Output op idx) = do
        {# set TF_Output->oper #} ptr op
        {# set TF_Output->index #} ptr (fromIntegral idx)
-- Buffer.
-- | Empty tag type standing in for the C struct @TF_Buffer@; only ever used
-- behind a 'BufferPtr'.
data Buffer

-- | Alias for @Ptr Buffer@ (a C @TF_Buffer*@).
{# pointer *TF_Buffer as BufferPtr -> Buffer #}

-- | Read the @data@ field of a @TF_Buffer@.
getBufferData :: BufferPtr -> IO (Ptr ())
getBufferData buf = {# get TF_Buffer->data #} buf

-- | Read the @length@ field of a @TF_Buffer@.
getBufferLength :: BufferPtr -> IO CULong
getBufferLength buf = {# get TF_Buffer->length #} buf

-- | Create a @TF_Buffer@ from a data pointer and length
-- (@TF_NewBufferFromString@). Release it with 'deleteBuffer'.
newBufferFromString :: Ptr () -> CULong -> IO BufferPtr
newBufferFromString bytes len = {# call TF_NewBufferFromString as ^ #} bytes len

-- | Free a @TF_Buffer@ (@TF_DeleteBuffer@).
deleteBuffer :: BufferPtr -> IO ()
deleteBuffer buf = {# call TF_DeleteBuffer as ^ #} buf
-- Tensor.
-- | Opaque handle to a C @TF_Tensor@.
{# pointer *TF_Tensor as Tensor newtype #}

-- | A 'Tensor' is stored exactly like its wrapped @Ptr Tensor@, so the
-- @Ptr Tensor@ arrays handed to 'run' hold one pointer per element.
-- 'sizeOf' and 'alignment' never inspect their argument.
instance Storable Tensor where
    sizeOf _ = sizeOf (nullPtr :: Ptr Tensor)
    alignment _ = alignment (nullPtr :: Ptr Tensor)
    peek ptr = Tensor <$> peek (castPtr ptr)
    poke ptr (Tensor raw) = poke (castPtr ptr) raw
-- | A synonym for the int64_t type, which is used in the TensorFlow API.
-- On some platforms it's `long`; on others (e.g., Mac OS X) it's `long long`;
-- and as far as Haskell is concerned, those are distinct types (`CLong` vs
-- `CLLong`). The c2hs @type@ hook picks the matching one for the platform
-- the bindings are generated on.
type CInt64 = {#type int64_t #}
-- | Alias for @Ptr CSize@ (a C @size_t*@), taken by 'graphNextOperation'
-- (presumably as its iteration position — confirm against c_api.h).
{# pointer *size_t as CSizePtr -> CSize #}
-- | Create a tensor over caller-supplied memory (@TF_NewTensor@).
--
-- The 'DataType' is converted to its numeric C enum value. Every other
-- argument is forwarded to the C function unchanged.
-- NOTE(review): TensorFlow presumably calls the deallocator with
-- (data, length, deallocator argument) once it is done with the memory —
-- confirm against c_api.h.
newTensor :: DataType
          -> Ptr CInt64  -- dimensions array
          -> CInt        -- num dimensions
          -> Ptr ()      -- data
          -> CULong      -- data len
          -> FunPtr (Ptr () -> CULong -> Ptr () -> IO ())  -- deallocator
          -> Ptr ()      -- deallocator arg
          -> IO Tensor
newTensor = {# call TF_NewTensor as ^ #} . fromIntegral . fromEnum
-- | Free a tensor (@TF_DeleteTensor@).
deleteTensor :: Tensor -> IO ()
deleteTensor tensor = {# call TF_DeleteTensor as ^ #} tensor

-- | Element type of a tensor (@TF_TensorType@), decoded into 'DataType'.
tensorType :: Tensor -> IO DataType
tensorType = fmap (toEnum . fromIntegral) . {# call TF_TensorType as ^ #}

-- | Number of dimensions of a tensor (@TF_NumDims@).
numDims :: Tensor -> IO CInt
numDims tensor = {# call TF_NumDims as ^ #} tensor

-- | Size of the dimension at the given index (@TF_Dim@).
dim :: Tensor -> CInt -> IO CInt64
dim tensor axis = {# call TF_Dim as ^ #} tensor axis

-- | Size of a tensor's data buffer (@TF_TensorByteSize@).
tensorByteSize :: Tensor -> IO CULong
tensorByteSize tensor = {# call TF_TensorByteSize as ^ #} tensor

-- | Pointer to a tensor's data buffer (@TF_TensorData@).
tensorData :: Tensor -> IO (Ptr ())
tensorData tensor = {# call TF_TensorData as ^ #} tensor
-- ImportGraphDefOptions.
-- | Opaque handle to a C @TF_ImportGraphDefOptions@.
{# pointer *TF_ImportGraphDefOptions as ImportGraphDefOptions newtype #}
-- | Allocate an options object (@TF_NewImportGraphDefOptions@).
{# fun TF_NewImportGraphDefOptions as newImportGraphDefOptions { } -> `ImportGraphDefOptions' #}
-- | Free an options object (@TF_DeleteImportGraphDefOptions@).
{# fun TF_DeleteImportGraphDefOptions as deleteImportGraphDefOptions { `ImportGraphDefOptions' } -> `()' #}
-- | Add an input mapping to the options
-- (@TF_ImportGraphDefOptionsAddInputMapping@). The 'ByteString' is passed
-- as a temporary C string via 'useAsCString'. The @%@ marker makes c2hs
-- dereference the 'OutputPtr' so the @TF_Output@ struct is passed by value.
-- NOTE(review): per c_api.h the arguments are presumably the source name,
-- source index and destination output — confirm.
{# fun TF_ImportGraphDefOptionsAddInputMapping as importGraphDefOptionsAddInputMapping
    { `ImportGraphDefOptions'
    , useAsCString* `ByteString'
    , `Int'
    , %`OutputPtr'
    } -> `()'
#}
-- Graph.
-- | Opaque handle to a C @TF_Graph@.
{# pointer *TF_Graph as Graph newtype #}
-- | Allocate a graph (@TF_NewGraph@).
{# fun TF_NewGraph as newGraph { } -> `Graph' #}
-- | Free a graph (@TF_DeleteGraph@).
{# fun TF_DeleteGraph as deleteGraph { `Graph' } -> `()' #}
-- | Look up an operation by name (@TF_GraphOperationByName@). The name is
-- passed as a temporary C string via 'useAsCString'.
-- NOTE(review): the C function presumably returns NULL when no operation
-- has that name. The resulting 'Operation' would then wrap 'nullPtr', so
-- callers should check for it.
{# fun TF_GraphOperationByName as graphOperationByName
    { `Graph'
    , useAsCString* `ByteString'
    } -> `Operation'
#}
-- | Step through a graph's operations (@TF_GraphNextOperation@), threading
-- a 'CSizePtr' position between calls.
-- NOTE(review): confirm the end-of-iteration result (likely NULL) in c_api.h.
{# fun TF_GraphNextOperation as graphNextOperation { `Graph', `CSizePtr' } -> `Operation' #}
-- | Import the graph definition held in a 'BufferPtr' into the graph using
-- the given options (@TF_GraphImportGraphDef@). The 'Status' is passed to
-- the C call, presumably to receive any error.
{# fun TF_GraphImportGraphDef as graphImportGraphDef { `Graph', `BufferPtr', `ImportGraphDefOptions', `Status' } -> `()' #}
-- Session Options.
-- | Opaque handle to a C @TF_SessionOptions@.
{# pointer *TF_SessionOptions as SessionOptions newtype #}

-- | Allocate session options (@TF_NewSessionOptions@); pair with
-- 'deleteSessionOptions'.
newSessionOptions :: IO SessionOptions
newSessionOptions = {# call TF_NewSessionOptions as ^ #}

-- | Set the session target from a C string (@TF_SetTarget@).
setTarget :: SessionOptions -> CString -> IO ()
setTarget opts target = {# call TF_SetTarget as ^ #} opts target

-- | Set the session config from a pointer and length (@TF_SetConfig@). The
-- 'Status' is passed to the C call, presumably to receive any error.
setConfig :: SessionOptions -> Ptr () -> CULong -> Status -> IO ()
setConfig opts proto len status = {# call TF_SetConfig as ^ #} opts proto len status

-- | Free session options (@TF_DeleteSessionOptions@).
deleteSessionOptions :: SessionOptions -> IO ()
deleteSessionOptions opts = {# call TF_DeleteSessionOptions as ^ #} opts
-- Session.
-- | Opaque handle to a C @TF_Session@.
{# pointer *TF_Session as Session newtype #}

-- | Create a session over a graph (@TF_NewSession@). The 'Status' is passed
-- to the C call, presumably to receive any error.
newSession :: Graph -> SessionOptions -> Status -> IO Session
newSession graph opts status = {# call TF_NewSession as ^ #} graph opts status

-- | Load a session from a SavedModel export directory
-- (@TF_LoadSessionFromSavedModel@). The directory is passed via
-- 'useAsCString'. The tags are passed as a C string array plus its length
-- via 'withStringArrayLen'; the @&@ marker splits that pair into two C
-- arguments.
{# fun TF_LoadSessionFromSavedModel as loadSessionFromSavedModel
    { `SessionOptions'
    , `BufferPtr'                          -- RunOptions proto.
    , useAsCString* `ByteString'           -- Export directory.
    , withStringArrayLen* `[ByteString]'&  -- Tags.
    , `Graph'
    , `BufferPtr'                          -- MetaGraphDef.
    , `Status'
    } -> `Session'
#}

-- | Close a session (@TF_CloseSession@).
closeSession :: Session -> Status -> IO ()
closeSession session status = {# call TF_CloseSession as ^ #} session status

-- | Free a session (@TF_DeleteSession@).
deleteSession :: Session -> Status -> IO ()
deleteSession session status = {# call TF_DeleteSession as ^ #} session status
-- | Run a session (@TF_SessionRun@).
--
-- Arguments come in groups of a C array plus its element count: the
-- inputs (outputs to feed and the tensors fed to them), the outputs to
-- fetch (with a tensor array for the results), and the target operations.
-- Every argument is forwarded to the C function unchanged.
run :: Session
    -> BufferPtr                        -- RunOptions proto.
    -> OutputPtr -> Ptr Tensor -> CInt  -- Input (names, tensors, count).
    -> OutputPtr -> Ptr Tensor -> CInt  -- Output (names, tensors, count).
    -> Ptr Operation -> CInt            -- Target operations (ops, count).
    -> BufferPtr                        -- RunMetadata proto.
    -> Status
    -> IO ()
run session runOpts
    inNames inTensors nIn
    outNames outTensors nOut
    targets nTargets
    runMeta status =
        {# call TF_SessionRun as ^ #}
            session runOpts
            inNames inTensors nIn
            outNames outTensors nOut
            targets nTargets
            runMeta status
-- FFI helpers.
-- | Haskell type of a tensor deallocator callback. It has the same shape as
-- the function pointer 'newTensor' accepts: data pointer, length, extra
-- argument.
type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO ()
-- | Wrap a Haskell deallocator as a C function pointer for 'newTensor'.
-- NOTE(review): a FunPtr produced by a "wrapper" import stays allocated
-- until it is released with 'freeHaskellFunPtr'; confirm callers do so.
foreign import ccall "wrapper"
    wrapTensorDealloc :: TensorDeallocFn -> IO (FunPtr TensorDeallocFn)
-- | Get the OpList of all OpDefs defined in this address space
-- (@TF_GetAllOpList@).
-- Returns a 'BufferPtr', ownership of which is transferred to the caller
-- (and can be freed using 'deleteBuffer').
--
-- The data in the buffer will be the serialized OpList proto for ops registered
-- in this address space.
getAllOpList :: IO BufferPtr
getAllOpList = {# call TF_GetAllOpList as ^ #}
-- | Use a list of ByteString as a list of CString.
--
-- Each 'ByteString' is marshalled with 'useAsCString', nesting the calls so
-- the first string's scope is outermost. The continuation receives the
-- C strings in the same order as the input list. The pointers are only
-- valid inside the continuation.
withStringList :: [ByteString] -> ([CString] -> IO a) -> IO a
-- TODO(fmayle): Is it worth using unsafeAsCString here?
withStringList = withMany useAsCString
-- | Use a list of ByteString as an array of CString with its length.
--
-- The strings are marshalled with 'withStringList' and copied into a
-- temporary C array by 'withArrayLen'. The continuation receives the array
-- pointer paired with the element count as a 'CInt', the shape used for
-- the tags argument of 'loadSessionFromSavedModel'.
withStringArrayLen :: [ByteString] -> ((Ptr CString, CInt) -> IO a) -> IO a
withStringArrayLen strs k = withStringList strs toArray
  where
    toArray cstrs = withArrayLen cstrs (\count arr -> k (arr, fromIntegral count))