mirror of
https://github.com/tensorflow/haskell.git
synced 2025-03-31 02:55:13 +02:00
Includes temporary blacklisting for a couple of ops that will be supported once my fix lands in the main tensorflow repo.
158 lines
4.9 KiB
Text
158 lines
4.9 KiB
Text
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE ForeignFunctionInterface #-}
|
|
|
|
module TensorFlow.Internal.Raw where
|
|
|
|
#include "third_party/tensorflow/c/c_api.h"
|
|
|
|
import Foreign
|
|
import Foreign.C
|
|
|
|
-- | Haskell mirror of the C enum @TF_DataType@ from c_api.h; constructors are
-- generated by c2hs. Used as the element type of a 'Tensor' (see
-- 'tensorType' and 'newTensor').
{#enum TF_DataType as DataType {} deriving (Show, Eq) #}

-- | Haskell mirror of the C enum @TF_Code@ from c_api.h; constructors are
-- generated by c2hs. This is the code stored in a 'Status' (see 'getCode'
-- and 'setStatus').
{#enum TF_Code as Code {} deriving (Show, Eq) #}
|
|
|
|
|
|
-- Status.
--
-- 'Status' wraps a @TF_Status*@ handle. Several functions in this module take
-- a 'Status' argument; 'getCode' and 'message' read back what it holds.
{#pointer *TF_Status as Status newtype #}

-- | Allocate a new status object (@TF_NewStatus@).
-- NOTE(review): presumably must be released with 'deleteStatus' — confirm.
newStatus :: IO Status
newStatus = {# call TF_NewStatus as ^ #}

-- | Release a status object (@TF_DeleteStatus@).
deleteStatus :: Status -> IO ()
deleteStatus status = {# call TF_DeleteStatus as ^ #} status

-- | Store a code and a message in a status object (@TF_SetStatus@).
-- The 'Code' is handed to C as its underlying integer value.
setStatus :: Status -> Code -> CString -> IO ()
setStatus status code msg =
    {# call TF_SetStatus as ^ #} status (fromIntegral (fromEnum code)) msg

-- | Read the code held by a status object (@TF_GetCode@), decoding the C
-- integer back into a 'Code'.
getCode :: Status -> IO Code
getCode = fmap (toEnum . fromIntegral) . {# call TF_GetCode as ^ #}

-- | Read the message held by a status object (@TF_Message@).
-- NOTE(review): the returned string is presumably owned by the 'Status';
-- confirm its lifetime in c_api.h before keeping it around.
message :: Status -> IO CString
message status = {# call TF_Message as ^ #} status
|
|
|
|
|
|
-- Buffer.
--
-- 'Buffer' is an empty Haskell stand-in for the C struct @TF_Buffer@; in this
-- module it is only used as the pointee of 'BufferPtr'.
data Buffer
{#pointer *TF_Buffer as BufferPtr -> Buffer #}

-- | Read the @data@ field of the @TF_Buffer@ pointed to.
getBufferData :: BufferPtr -> IO (Ptr ())
getBufferData buf = {#get TF_Buffer->data #} buf

-- | Read the @length@ field of the @TF_Buffer@ pointed to.
getBufferLength :: BufferPtr -> IO CULong
getBufferLength buf = {#get TF_Buffer->length #} buf
|
|
|
|
-- Tensor.
--
-- 'Tensor' wraps a @TF_Tensor*@ handle.
{#pointer *TF_Tensor as Tensor newtype #}

-- | A 'Tensor' is laid out in memory exactly like the raw pointer it wraps, so
-- arrays of tensor handles (such as the @Ptr Tensor@ arguments of 'run') can
-- be marshalled with the standard "Foreign" utilities.
instance Storable Tensor where
    sizeOf _ = sizeOf (nullPtr :: Ptr Tensor)
    alignment _ = alignment (nullPtr :: Ptr Tensor)
    peek = fmap Tensor . peek . castPtr
    poke ptr (Tensor raw) = poke (castPtr ptr) raw
|
|
|
|
-- | A synonym for the C @int64_t@ type, which is used in the TensorFlow API.
-- On some platforms it's @long@; on others (e.g., Mac OS X) it's @long long@;
-- and as far as Haskell is concerned, those are distinct types ('CLong' vs
-- 'CLLong'). Letting c2hs resolve @int64_t@ from the C headers selects
-- whichever of the two matches the platform being built.
type CInt64 = {#type int64_t #}
|
|
|
|
-- | Create a tensor around an existing data buffer (@TF_NewTensor@).
--
-- The deallocator is typed with 'TensorDeallocFn', the same type that
-- 'wrapTensorDealloc' produces function pointers for. This keeps the two
-- signatures from drifting apart.
-- NOTE(review): the deallocator is presumably invoked with the data pointer,
-- data length and deallocator argument — confirm against c_api.h.
newTensor :: DataType               -- ^ element type
          -> Ptr CInt64             -- ^ dimensions array
          -> CInt                   -- ^ num dimensions
          -> Ptr ()                 -- ^ data
          -> CULong                 -- ^ data len
          -> FunPtr TensorDeallocFn -- ^ deallocator
          -> Ptr ()                 -- ^ deallocator arg
          -> IO Tensor
-- The 'DataType' enum is passed to C as its underlying integer value.
newTensor dt = {# call TF_NewTensor as ^ #} (fromIntegral $ fromEnum dt)
|
|
|
|
-- | Release a tensor (@TF_DeleteTensor@).
deleteTensor :: Tensor -> IO ()
deleteTensor tensor = {# call TF_DeleteTensor as ^ #} tensor

-- | Element type of a tensor (@TF_TensorType@), decoded from its C integer
-- representation into a 'DataType'.
tensorType :: Tensor -> IO DataType
tensorType = fmap (toEnum . fromIntegral) . {# call TF_TensorType as ^ #}

-- | Number of dimensions of a tensor, as reported by @TF_NumDims@.
numDims :: Tensor -> IO CInt
numDims tensor = {# call TF_NumDims as ^ #} tensor

-- | Size of the dimension at the given index, as reported by @TF_Dim@.
dim :: Tensor -> CInt -> IO CInt64
dim tensor index = {# call TF_Dim as ^ #} tensor index

-- | Byte size of a tensor's data, as reported by @TF_TensorByteSize@.
tensorByteSize :: Tensor -> IO CULong
tensorByteSize tensor = {# call TF_TensorByteSize as ^ #} tensor

-- | Pointer to a tensor's data, as returned by @TF_TensorData@.
tensorData :: Tensor -> IO (Ptr ())
tensorData tensor = {# call TF_TensorData as ^ #} tensor
|
|
|
|
|
|
-- Session Options.
--
-- 'SessionOptions' wraps a @TF_SessionOptions*@ handle.
{# pointer *TF_SessionOptions as SessionOptions newtype #}

-- | Allocate a session-options object (@TF_NewSessionOptions@).
newSessionOptions :: IO SessionOptions
newSessionOptions = {# call TF_NewSessionOptions as ^ #}

-- | Set the target string of a session-options object (@TF_SetTarget@).
setTarget :: SessionOptions -> CString -> IO ()
setTarget opts target = {# call TF_SetTarget as ^ #} opts target

-- | Pass a configuration, given as a raw buffer plus its length, to a
-- session-options object (@TF_SetConfig@). Takes a 'Status', presumably for
-- error reporting.
-- NOTE(review): the buffer is presumably a serialized ConfigProto — confirm.
setConfig :: SessionOptions -> Ptr () -> CULong -> Status -> IO ()
setConfig opts proto protoLen status =
    {# call TF_SetConfig as ^ #} opts proto protoLen status

-- | Release a session-options object (@TF_DeleteSessionOptions@).
deleteSessionOptions :: SessionOptions -> IO ()
deleteSessionOptions opts = {# call TF_DeleteSessionOptions as ^ #} opts
|
|
|
|
|
|
-- Session.
--
-- 'Session' wraps a @TF_DeprecatedSession*@ handle.
{# pointer *TF_DeprecatedSession as Session newtype #}

-- | Create a session from a session-options object
-- (@TF_NewDeprecatedSession@).
newSession :: SessionOptions -> Status -> IO Session
newSession opts status = {# call TF_NewDeprecatedSession as ^ #} opts status

-- | Close a session (@TF_CloseDeprecatedSession@).
closeSession :: Session -> Status -> IO ()
closeSession session status =
    {# call TF_CloseDeprecatedSession as ^ #} session status

-- | Release a session (@TF_DeleteDeprecatedSession@).
deleteSession :: Session -> Status -> IO ()
deleteSession session status =
    {# call TF_DeleteDeprecatedSession as ^ #} session status

-- | Extend a session's graph from a raw buffer plus its length
-- (@TF_ExtendGraph@).
-- NOTE(review): the buffer is presumably a serialized GraphDef — confirm.
extendGraph :: Session -> Ptr () -> CULong -> Status -> IO ()
extendGraph session proto protoLen status =
    {# call TF_ExtendGraph as ^ #} session proto protoLen status
|
|
|
|
-- | Run a step of a session (@TF_Run@). Inputs, outputs and target nodes are
-- each passed as parallel C arrays together with an element count.
run :: Session
    -> BufferPtr    -- ^ RunOptions proto.
    -> Ptr CString  -- ^ Input names.
    -> Ptr Tensor   -- ^ Input tensors.
    -> CInt         -- ^ Input count.
    -> Ptr CString  -- ^ Output names.
    -> Ptr Tensor   -- ^ Output tensors.
    -> CInt         -- ^ Output count.
    -> Ptr CString  -- ^ Target node names.
    -> CInt         -- ^ Target node count.
    -> BufferPtr    -- ^ RunMetadata proto.
    -> Status
    -> IO ()
run session runOptions
    inNames inTensors nIn
    outNames outTensors nOut
    targets nTargets
    runMetadata status =
        {# call TF_Run as ^ #} session runOptions
            inNames inTensors nIn
            outNames outTensors nOut
            targets nTargets
            runMetadata status
|
|
|
|
-- FFI helpers.

-- | Haskell type of a tensor deallocator callback; the deallocator argument of
-- 'newTensor' is a 'FunPtr' to a function of this type.
type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO ()

-- | Turn a Haskell 'TensorDeallocFn' into a C-callable function pointer.
--
-- Per the Haskell 2010 FFI, a pointer created by a @"wrapper"@ import remains
-- allocated until it is released with 'freeHaskellFunPtr'.
foreign import ccall "wrapper"
    wrapTensorDealloc :: TensorDeallocFn -> IO (FunPtr TensorDeallocFn)
|
|
|
|
|
|
-- | Get the OpList of all OpDefs defined in this address space.
-- Returns a BufferPtr, ownership of which is transferred to the caller.
-- Note that 'deleteBuffer' is a finalizer pointer, not a callable action, so
-- the buffer is released by attaching it, e.g. @newForeignPtr deleteBuffer p@.
--
-- The data in the buffer will be the serialized OpList proto for ops registered
-- in this address space.
getAllOpList :: IO BufferPtr
getAllOpList = {# call TF_GetAllOpList as ^ #}

-- | Address of the C function @TF_DeleteBuffer@. Its type equals
-- @'FinalizerPtr' 'Buffer'@ (since 'BufferPtr' is @'Ptr' 'Buffer'@), so it can
-- be used as a 'ForeignPtr' finalizer for buffers such as the one returned by
-- 'getAllOpList'.
foreign import ccall "&TF_DeleteBuffer"
    deleteBuffer :: FunPtr (BufferPtr -> IO ())
|