1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 19:13:34 +02:00

Use safe version of fromIntegral in FFI code

This commit is contained in:
Frederick Mayle 2016-11-15 01:28:03 -08:00
parent 5616b1b008
commit 058364a634

View File

@ -14,6 +14,7 @@
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module TensorFlow.Internal.FFI
( TensorFlowException(..)
@ -34,7 +35,10 @@ import Control.Concurrent.Async (Async, async, cancel, waitCatch)
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
import Control.Exception (Exception, throwIO, bracket, finally, mask_)
import Control.Monad (when)
import Data.Bits (Bits, toIntegralSized)
import Data.Data (Data, dataTypeName, dataTypeOf)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
@ -134,9 +138,9 @@ run session feeds fetches targets = do
checkStatus $ Raw.run
session
nullPtr
feedNames cFeedTensors (fromIntegral feedsLen)
fetchNames tensorOuts (fromIntegral fetchesLen)
ctargets (fromIntegral targetsLen)
feedNames cFeedTensors (safeConvert feedsLen)
fetchNames tensorOuts (safeConvert fetchesLen)
ctargets (safeConvert targetsLen)
nullPtr
outTensors <- peekArray fetchesLen tensorOuts
mapM createTensorData outTensors
@ -145,6 +149,17 @@ run session feeds fetches targets = do
-- Internal.
-- | Same as 'fromIntegral', but throws an error if conversion is "lossy".
safeConvert ::
forall a b. (Show a, Show b, Bits a, Bits b, Integral a, Integral b)
=> a -> b
safeConvert x =
fromMaybe
(error ("Failed to convert " ++ show x ++ ", got " ++
show (fromIntegral x :: b)))
(toIntegralSized x)
-- | Use a list of ByteString as a list of CString.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = go strings []
@ -162,13 +177,13 @@ withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
-- | Create a Raw.Tensor from a TensorData.
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) =
withArrayLen (map fromIntegral dims) $ \cdimsLen cdims -> do
withArrayLen (map safeConvert dims) $ \cdimsLen cdims -> do
let len = S.length byteVec
dest <- mallocArray len
S.unsafeWith byteVec $ \x -> copyArray dest x len
Raw.newTensor (toEnum $ fromEnum dt)
cdims (fromIntegral cdimsLen)
(castPtr dest) (fromIntegral len)
cdims (safeConvert cdimsLen)
(castPtr dest) (safeConvert len)
tensorDeallocFunPtr nullPtr
{-# NOINLINE tensorDeallocFunPtr #-}
@ -186,7 +201,7 @@ createTensorData t = do
-- Read type.
dtype <- toEnum . fromEnum <$> Raw.tensorType t
-- Read data.
len <- fromIntegral <$> Raw.tensorByteSize t
len <- safeConvert <$> Raw.tensorByteSize t
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
-- Note: We would like to avoid the copy by creating an S.Vector directly
-- from 'bytes' and calling Raw.deleteTensor when it gets GC'd, but we can't
@ -196,7 +211,7 @@ createTensorData t = do
v <- S.unsafeFreeze mv
-- Free tensor.
Raw.deleteTensor t
return $ TensorData (map fromIntegral dims) dtype v
return $ TensorData (map safeConvert dims) dtype v
-- | Runs the given action which does FFI calls updating a provided
-- status object. If the status is not OK it is thrown as
@ -222,10 +237,10 @@ setSessionTarget target = B.useAsCString target . Raw.setTarget
-- | Serializes the given msg and provides it as (ptr,len) argument
-- to the given action.
useProtoAsVoidPtrLen :: (Message msg, Num c) =>
useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
msg -> (Ptr b -> c -> IO a) -> IO a
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
\(bytes, len) -> f (castPtr bytes) (fromIntegral len)
\(bytes, len) -> f (castPtr bytes) (safeConvert len)
-- | Returns the serialized OpList of all OpDefs defined in this
-- address space.
@ -238,7 +253,7 @@ getAllOpList = do
withForeignPtr foreignPtr $
\ptr -> B.packCStringLen =<< (,)
<$> (castPtr <$> Raw.getBufferData ptr)
<*> (fromIntegral <$> Raw.getBufferLength ptr)
<*> (safeConvert <$> Raw.getBufferLength ptr)
where
checkCall = do
p <- Raw.getAllOpList