mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 19:13:34 +02:00
Use safe version of fromIntegral in FFI code
This commit is contained in:
parent
5616b1b008
commit
058364a634
|
@ -14,6 +14,7 @@
|
|||
|
||||
{-# LANGUAGE DeriveDataTypeable #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module TensorFlow.Internal.FFI
|
||||
( TensorFlowException(..)
|
||||
|
@ -34,7 +35,10 @@ import Control.Concurrent.Async (Async, async, cancel, waitCatch)
|
|||
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
|
||||
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
|
||||
import Control.Monad (when)
|
||||
import Data.Bits (Bits, toIntegralSized)
|
||||
import Data.Data (Data, dataTypeName, dataTypeOf)
|
||||
import Data.Int (Int64)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Typeable (Typeable)
|
||||
import Data.Word (Word8)
|
||||
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
|
||||
|
@ -134,9 +138,9 @@ run session feeds fetches targets = do
|
|||
checkStatus $ Raw.run
|
||||
session
|
||||
nullPtr
|
||||
feedNames cFeedTensors (fromIntegral feedsLen)
|
||||
fetchNames tensorOuts (fromIntegral fetchesLen)
|
||||
ctargets (fromIntegral targetsLen)
|
||||
feedNames cFeedTensors (safeConvert feedsLen)
|
||||
fetchNames tensorOuts (safeConvert fetchesLen)
|
||||
ctargets (safeConvert targetsLen)
|
||||
nullPtr
|
||||
outTensors <- peekArray fetchesLen tensorOuts
|
||||
mapM createTensorData outTensors
|
||||
|
@ -145,6 +149,17 @@ run session feeds fetches targets = do
|
|||
-- Internal.
|
||||
|
||||
|
||||
-- | Same as 'fromIntegral', but throws an error if conversion is "lossy".
|
||||
safeConvert ::
|
||||
forall a b. (Show a, Show b, Bits a, Bits b, Integral a, Integral b)
|
||||
=> a -> b
|
||||
safeConvert x =
|
||||
fromMaybe
|
||||
(error ("Failed to convert " ++ show x ++ ", got " ++
|
||||
show (fromIntegral x :: b)))
|
||||
(toIntegralSized x)
|
||||
|
||||
|
||||
-- | Use a list of ByteString as a list of CString.
|
||||
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
|
||||
withStringList strings fn = go strings []
|
||||
|
@ -162,13 +177,13 @@ withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
|
|||
-- | Create a Raw.Tensor from a TensorData.
|
||||
createRawTensor :: TensorData -> IO Raw.Tensor
|
||||
createRawTensor (TensorData dims dt byteVec) =
|
||||
withArrayLen (map fromIntegral dims) $ \cdimsLen cdims -> do
|
||||
withArrayLen (map safeConvert dims) $ \cdimsLen cdims -> do
|
||||
let len = S.length byteVec
|
||||
dest <- mallocArray len
|
||||
S.unsafeWith byteVec $ \x -> copyArray dest x len
|
||||
Raw.newTensor (toEnum $ fromEnum dt)
|
||||
cdims (fromIntegral cdimsLen)
|
||||
(castPtr dest) (fromIntegral len)
|
||||
cdims (safeConvert cdimsLen)
|
||||
(castPtr dest) (safeConvert len)
|
||||
tensorDeallocFunPtr nullPtr
|
||||
|
||||
{-# NOINLINE tensorDeallocFunPtr #-}
|
||||
|
@ -186,7 +201,7 @@ createTensorData t = do
|
|||
-- Read type.
|
||||
dtype <- toEnum . fromEnum <$> Raw.tensorType t
|
||||
-- Read data.
|
||||
len <- fromIntegral <$> Raw.tensorByteSize t
|
||||
len <- safeConvert <$> Raw.tensorByteSize t
|
||||
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
|
||||
-- Note: We would like to avoid the copy by creating an S.Vector directly
|
||||
-- from 'bytes' and calling Raw.deleteTensor when it gets GC'd, but we can't
|
||||
|
@ -196,7 +211,7 @@ createTensorData t = do
|
|||
v <- S.unsafeFreeze mv
|
||||
-- Free tensor.
|
||||
Raw.deleteTensor t
|
||||
return $ TensorData (map fromIntegral dims) dtype v
|
||||
return $ TensorData (map safeConvert dims) dtype v
|
||||
|
||||
-- | Runs the given action which does FFI calls updating a provided
|
||||
-- status object. If the status is not OK it is thrown as
|
||||
|
@ -222,10 +237,10 @@ setSessionTarget target = B.useAsCString target . Raw.setTarget
|
|||
|
||||
-- | Serializes the given msg and provides it as (ptr,len) argument
|
||||
-- to the given action.
|
||||
useProtoAsVoidPtrLen :: (Message msg, Num c) =>
|
||||
useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
|
||||
msg -> (Ptr b -> c -> IO a) -> IO a
|
||||
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
|
||||
\(bytes, len) -> f (castPtr bytes) (fromIntegral len)
|
||||
\(bytes, len) -> f (castPtr bytes) (safeConvert len)
|
||||
|
||||
-- | Returns the serialized OpList of all OpDefs defined in this
|
||||
-- address space.
|
||||
|
@ -238,7 +253,7 @@ getAllOpList = do
|
|||
withForeignPtr foreignPtr $
|
||||
\ptr -> B.packCStringLen =<< (,)
|
||||
<$> (castPtr <$> Raw.getBufferData ptr)
|
||||
<*> (fromIntegral <$> Raw.getBufferLength ptr)
|
||||
<*> (safeConvert <$> Raw.getBufferLength ptr)
|
||||
where
|
||||
checkCall = do
|
||||
p <- Raw.getAllOpList
|
||||
|
|
Loading…
Reference in New Issue
Block a user