tensorflow-haskell/tensorflow/src/TensorFlow/Internal/Raw.chs

153 lines
4.6 KiB
Plaintext

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ForeignFunctionInterface #-}
module TensorFlow.Internal.Raw where
#include "third_party/tensorflow/c/c_api.h"
import Foreign
import Foreign.C
-- | Element type of a tensor, generated by c2hs from the C enum
-- @TF_DataType@. The empty alias list (@{}@) keeps the C enumerator names.
{#enum TF_DataType as DataType {} deriving (Show, Eq) #}
-- | Result code carried by a 'Status' (C enum @TF_Code@); read with
-- 'getCode' and written with 'setStatus'.
{#enum TF_Code as Code {} deriving (Show, Eq) #}
-- Status.
-- | Raw handle to a C @TF_Status@. This is a plain pointer newtype (not a
-- ForeignPtr), so its memory is not managed by the GC: allocate with
-- 'newStatus' and release with 'deleteStatus'.
{#pointer *TF_Status as Status newtype #}
-- | Allocate a fresh status object via @TF_NewStatus@.
--
-- The caller owns the result and should release it with 'deleteStatus'.
newStatus :: IO Status
newStatus =
    {# call TF_NewStatus as ^ #}
-- | Free a status previously obtained from 'newStatus' (@TF_DeleteStatus@).
deleteStatus :: Status -> IO ()
deleteStatus status = {# call TF_DeleteStatus as ^ #} status
-- | Overwrite the code and message stored in a status (@TF_SetStatus@).
--
-- The 'Code' is converted to its C enumerator value before the call.
setStatus :: Status -> Code -> CString -> IO ()
setStatus status code msg =
    {# call TF_SetStatus as ^ #} status (fromIntegral (fromEnum code)) msg
-- | Read the result code stored in a status (@TF_GetCode@).
--
-- NOTE: the C integer is mapped back with 'toEnum', which throws if the
-- library ever returns a value not present in the generated 'Code' enum.
getCode :: Status -> IO Code
getCode = fmap (toEnum . fromIntegral) . {# call TF_GetCode as ^ #}
-- | Read the message attached to a status (@TF_Message@).
--
-- The returned 'CString' is not copied; presumably it stays valid only
-- while the status is alive and unmodified — confirm against c_api.h.
message :: Status -> IO CString
message status = {# call TF_Message as ^ #} status
-- Buffer.
-- | Empty data type standing for the C struct @TF_Buffer@; it is only used
-- as the pointee type of 'BufferPtr'.
data Buffer
-- | Raw pointer to a C @TF_Buffer@. Its fields are read with
-- 'getBufferData' and 'getBufferLength'.
{#pointer *TF_Buffer as BufferPtr -> Buffer #}
-- | Read the @data@ field of a @TF_Buffer@.
getBufferData :: BufferPtr -> IO (Ptr ())
getBufferData buf = {#get TF_Buffer->data #} buf
-- | Read the @length@ field of a @TF_Buffer@, i.e. the size of the memory
-- returned by 'getBufferData'.
getBufferLength :: BufferPtr -> IO CULong
getBufferLength buf = {#get TF_Buffer->length #} buf
-- Tensor.
-- | Raw handle to a C @TF_Tensor@. Plain pointer newtype with no finalizer:
-- create with 'newTensor' and release with 'deleteTensor'.
{#pointer *TF_Tensor as Tensor newtype #}
-- | 'Storable' instance so that arrays of 'Tensor' handles can be
-- marshalled, e.g. the @Ptr Tensor@ input/output arrays passed to 'run'.
--
-- A 'Tensor' is stored exactly like the pointer it wraps. Matching on the
-- newtype constructor forces nothing, so @sizeOf (undefined :: Tensor)@
-- is safe.
instance Storable Tensor where
    sizeOf (Tensor t) = sizeOf t
    alignment (Tensor t) = alignment t
    peek p = fmap Tensor (peek (castPtr p))
    poke p (Tensor t) = poke (castPtr p) t
-- | Wrap caller-provided memory in a new @TF_Tensor@ (@TF_NewTensor@).
--
-- Per c_api.h, the tensor takes ownership of @data@ rather than copying
-- it. When TensorFlow releases the data, it calls the deallocator with
-- (data, len, deallocator argument). Build the deallocator with
-- 'wrapTensorDealloc'.
newTensor :: DataType                 -- ^ element type
          -> Ptr CLong                -- ^ dimensions array
          -> CInt                     -- ^ number of dimensions
          -> Ptr ()                   -- ^ data
          -> CULong                   -- ^ data length in bytes
          -> FunPtr TensorDeallocFn   -- ^ deallocator
          -> Ptr ()                   -- ^ deallocator argument
          -> IO Tensor
-- The remaining arguments pass straight through; only the enum needs
-- converting to its C integer value.
newTensor dt = {# call TF_NewTensor as ^ #} (fromIntegral $ fromEnum dt)
-- | Release a tensor (@TF_DeleteTensor@).
deleteTensor :: Tensor -> IO ()
deleteTensor tensor = {# call TF_DeleteTensor as ^ #} tensor
-- | Element type of a tensor (@TF_TensorType@).
--
-- NOTE: the C value is mapped with 'toEnum', which throws on a value not
-- present in the generated 'DataType' enum.
tensorType :: Tensor -> IO DataType
tensorType = fmap (toEnum . fromIntegral) . {# call TF_TensorType as ^ #}
-- | Number of dimensions (rank) of a tensor (@TF_NumDims@).
numDims :: Tensor -> IO CInt
numDims tensor = {# call TF_NumDims as ^ #} tensor
-- | Length of one dimension of a tensor (@TF_Dim@). Per c_api.h, the
-- index must satisfy @0 <= i < numDims tensor@.
dim :: Tensor -> CInt -> IO CLong
dim tensor i = {# call TF_Dim as ^ #} tensor i
-- | Size in bytes of a tensor's data buffer (@TF_TensorByteSize@).
tensorByteSize :: Tensor -> IO CULong
tensorByteSize tensor = {# call TF_TensorByteSize as ^ #} tensor
-- | Pointer to a tensor's underlying data buffer (@TF_TensorData@).
tensorData :: Tensor -> IO (Ptr ())
tensorData tensor = {# call TF_TensorData as ^ #} tensor
-- Session Options.
-- | Raw handle to a C @TF_SessionOptions@. Create with 'newSessionOptions'
-- and release with 'deleteSessionOptions'.
{# pointer *TF_SessionOptions as SessionOptions newtype #}
-- | Allocate a fresh session-options object (@TF_NewSessionOptions@).
--
-- The caller owns the result and should release it with
-- 'deleteSessionOptions'.
newSessionOptions :: IO SessionOptions
newSessionOptions =
    {# call TF_NewSessionOptions as ^ #}
-- | Set the target string used by sessions created with these options
-- (@TF_SetTarget@).
setTarget :: SessionOptions -> CString -> IO ()
setTarget opts target = {# call TF_SetTarget as ^ #} opts target
-- | Install a serialized config proto (data pointer and byte length) into
-- the options (@TF_SetConfig@). Failures are reported through the
-- 'Status'.
setConfig :: SessionOptions -> Ptr () -> CULong -> Status -> IO ()
setConfig opts proto protoLen status =
    {# call TF_SetConfig as ^ #} opts proto protoLen status
-- | Release options obtained from 'newSessionOptions'
-- (@TF_DeleteSessionOptions@).
deleteSessionOptions :: SessionOptions -> IO ()
deleteSessionOptions opts = {# call TF_DeleteSessionOptions as ^ #} opts
-- Session.
-- | Raw handle to a C @TF_Session@. Create with 'newSession'; shut down
-- with 'closeSession' and free with 'deleteSession'.
{# pointer *TF_Session as Session newtype #}
-- | Create a session configured by the given options (@TF_NewSession@).
--
-- Errors are reported through the 'Status'. On failure the returned
-- handle is presumably null — confirm against c_api.h before using it.
newSession :: SessionOptions -> Status -> IO Session
newSession opts status = {# call TF_NewSession as ^ #} opts status
-- | Close a session (@TF_CloseSession@); errors go to the 'Status'.
closeSession :: Session -> Status -> IO ()
closeSession session status = {# call TF_CloseSession as ^ #} session status
-- | Destroy a session object (@TF_DeleteSession@); errors go to the
-- 'Status'. The handle must not be used afterwards.
deleteSession :: Session -> Status -> IO ()
deleteSession session status =
    {# call TF_DeleteSession as ^ #} session status
-- | Add the nodes of a serialized GraphDef (data pointer and byte length)
-- to the session's graph (@TF_ExtendGraph@). Errors go to the 'Status'.
extendGraph :: Session -> Ptr () -> CULong -> Status -> IO ()
extendGraph session proto protoLen status =
    {# call TF_ExtendGraph as ^ #} session proto protoLen status
-- | Execute one step of a session (@TF_Run@).
--
-- Feeds the named input tensors, fetches the named outputs into the
-- caller-supplied output array, and runs the listed target nodes. Errors
-- are reported through the 'Status' argument, not by exception.
run :: Session
    -> BufferPtr     -- ^ RunOptions proto (per c_api.h, may be null)
    -> Ptr CString   -- ^ input names
    -> Ptr Tensor    -- ^ input tensors
    -> CInt          -- ^ number of inputs
    -> Ptr CString   -- ^ output names
    -> Ptr Tensor    -- ^ output tensors, written by TF_Run on success
    -> CInt          -- ^ number of outputs
    -> Ptr CString   -- ^ target node names
    -> CInt          -- ^ number of target nodes
    -> BufferPtr     -- ^ RunMetadata proto (per c_api.h, may be null)
    -> Status        -- ^ receives the result code and message
    -> IO ()
run = {# call TF_Run as ^ #}
-- FFI helpers.
-- | Haskell type of a @TF_Tensor@ deallocator callback, as passed to
-- 'newTensor' (after conversion with 'wrapTensorDealloc').
type TensorDeallocFn
    =  Ptr ()   -- data pointer
    -> CULong   -- data length
    -> Ptr ()   -- deallocator argument
    -> IO ()
-- | Turn a Haskell 'TensorDeallocFn' into a C function pointer suitable as
-- the deallocator argument of 'newTensor'.
--
-- Each call allocates a new callback stub. Release it with
-- 'freeHaskellFunPtr' once TensorFlow can no longer invoke it; otherwise
-- the stub leaks.
foreign import ccall "wrapper"
    wrapTensorDealloc :: TensorDeallocFn -> IO (FunPtr TensorDeallocFn)
-- | Get the OpList of all OpDefs defined in this address space.
-- Returns a BufferPtr whose ownership is transferred to the caller.
--
-- 'deleteBuffer' is a 'FunPtr' finalizer, not a directly callable action.
-- To release the buffer, attach it, e.g. @newForeignPtr deleteBuffer buf@.
--
-- The data in the buffer is the serialized OpList proto for the ops
-- registered in this address space.
getAllOpList :: IO BufferPtr
getAllOpList = {# call TF_GetAllOpList as ^ #}
-- | Function pointer to @TF_DeleteBuffer@. Its type equals
-- @FinalizerPtr Buffer@, so it can be passed to 'newForeignPtr' to free a
-- 'BufferPtr' (such as the one returned by 'getAllOpList') automatically.
foreign import ccall "&TF_DeleteBuffer"
    deleteBuffer :: FunPtr (BufferPtr -> IO ())