1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 11:29:43 +01:00

Work around #92 by always copying TensorData when fetching.

It would be better to avoid the copy when it's not necessary, but
that will require more involved changes to the internal API.  (For example,
Fetchable might need to allow IO or ST actions.)
This commit is contained in:
Judah Jacobson 2017-05-08 17:45:56 -07:00 committed by fkm3
parent 37e3c9b084
commit a64af5076a
2 changed files with 23 additions and 4 deletions

View file

@ -94,10 +94,22 @@ testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
$ p1 `TF.add` p2 $ p1 `TF.add` p2
liftIO $ result @=? TF.Scalar 5 liftIO $ result @=? TF.Scalar 5
-- | See https://github.com/tensorflow/haskell/issues/92.
-- Even though we're not explicitly evaluating `f0` until the end,
-- it should hold the earlier value of the variable.
testRereadRef :: Test
testRereadRef = testCase "testReRunAssign" $ TF.runSession $ do
w <- TF.initializedVariable 0
f0 <- TF.run w
TF.run_ =<< TF.assign w (TF.scalar (0.1 :: Float))
f1 <- TF.run w
liftIO $ (0.0, 0.1) @=? (TF.unScalar f0, TF.unScalar f1)
main :: IO () main :: IO ()
main = googleTest [ testSaveRestore main = googleTest [ testSaveRestore
, testSize , testSize
, testReducedShape , testReducedShape
, testPlaceholderCse , testPlaceholderCse
, testScalarFeedCse , testScalarFeedCse
, testRereadRef
] ]

View file

@ -42,7 +42,7 @@ import Data.Typeable (Typeable)
import Data.Word (Word8) import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr) import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString) import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr) import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
import Foreign.Marshal.Alloc (free) import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray) import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO) import System.IO.Unsafe (unsafePerformIO)
@ -51,7 +51,7 @@ import qualified Data.Text as T
import qualified Data.Text.Encoding as T import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S import qualified Data.Vector.Storable as S
import qualified Foreign.Concurrent as ForeignC import qualified Data.Vector.Storable.Mutable as M
import Data.ProtoLens (Message, encodeMessage) import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef) import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
@ -193,6 +193,10 @@ tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x
-- | Create a TensorData from a Raw.Tensor. -- | Create a TensorData from a Raw.Tensor.
-- --
-- Takes ownership of the Raw.Tensor. -- Takes ownership of the Raw.Tensor.
-- TODO: Currently, it just makes a copy of the Tensor (and then deletes it),
-- since the raw pointer may refer to storage inside a mutable TensorFlow
-- variable. We should avoid that copy when it's not needed; for example,
-- by making TensorData wrap an IOVector, and changing the code that uses it.
createTensorData :: Raw.Tensor -> IO TensorData createTensorData :: Raw.Tensor -> IO TensorData
createTensorData t = do createTensorData t = do
-- Read dimensions. -- Read dimensions.
@ -203,8 +207,11 @@ createTensorData t = do
-- Read data. -- Read data.
len <- safeConvert <$> Raw.tensorByteSize t len <- safeConvert <$> Raw.tensorByteSize t
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8) bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
fp <- ForeignC.newForeignPtr bytes (Raw.deleteTensor t) fp <- newForeignPtr_ bytes
let v = S.unsafeFromForeignPtr0 fp len -- Make an explicit copy of the raw data, since it might point
-- to a mutable variable's memory.
v <- S.freeze (M.unsafeFromForeignPtr0 fp len)
Raw.deleteTensor t
return $ TensorData (map safeConvert dims) dtype v return $ TensorData (map safeConvert dims) dtype v
-- | Runs the given action which does FFI calls updating a provided -- | Runs the given action which does FFI calls updating a provided