diff --git a/tensorflow-ops/tests/OpsTest.hs b/tensorflow-ops/tests/OpsTest.hs index de35828..4824d6f 100644 --- a/tensorflow-ops/tests/OpsTest.hs +++ b/tensorflow-ops/tests/OpsTest.hs @@ -94,10 +94,22 @@ testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do $ p1 `TF.add` p2 liftIO $ result @=? TF.Scalar 5 +-- | See https://github.com/tensorflow/haskell/issues/92. +-- Even though we're not explicitly evaluating `f0` until the end, +-- it should hold the earlier value of the variable. +testRereadRef :: Test +testRereadRef = testCase "testReRunAssign" $ TF.runSession $ do + w <- TF.initializedVariable 0 + f0 <- TF.run w + TF.run_ =<< TF.assign w (TF.scalar (0.1 :: Float)) + f1 <- TF.run w + liftIO $ (0.0, 0.1) @=? (TF.unScalar f0, TF.unScalar f1) + main :: IO () main = googleTest [ testSaveRestore , testSize , testReducedShape , testPlaceholderCse , testScalarFeedCse + , testRereadRef ] diff --git a/tensorflow/src/TensorFlow/Internal/FFI.hs b/tensorflow/src/TensorFlow/Internal/FFI.hs index 03efe4d..59a79b3 100644 --- a/tensorflow/src/TensorFlow/Internal/FFI.hs +++ b/tensorflow/src/TensorFlow/Internal/FFI.hs @@ -42,7 +42,7 @@ import Data.Typeable (Typeable) import Data.Word (Word8) import Foreign (Ptr, FunPtr, nullPtr, castPtr) import Foreign.C.String (CString) -import Foreign.ForeignPtr (newForeignPtr, withForeignPtr) +import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr) import Foreign.Marshal.Alloc (free) import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray) import System.IO.Unsafe (unsafePerformIO) @@ -51,7 +51,7 @@ import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Data.Text.Encoding.Error as T import qualified Data.Vector.Storable as S -import qualified Foreign.Concurrent as ForeignC +import qualified Data.Vector.Storable.Mutable as M import Data.ProtoLens (Message, encodeMessage) import Proto.Tensorflow.Core.Framework.Graph (GraphDef) @@ -193,6 +193,10 @@ tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x -- | Create a TensorData from a Raw.Tensor. -- -- Takes ownership of the Raw.Tensor. +-- TODO: Currently, it just makes a copy of the Tensor (and then deletes it), +-- since the raw pointer may refer to storage inside a mutable TensorFlow +-- variable. We should avoid that copy when it's not needed; for example, +-- by making TensorData wrap an IOVector, and changing the code that uses it. createTensorData :: Raw.Tensor -> IO TensorData createTensorData t = do -- Read dimensions. @@ -203,8 +207,11 @@ createTensorData t = do -- Read data. len <- safeConvert <$> Raw.tensorByteSize t bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8) - fp <- ForeignC.newForeignPtr bytes (Raw.deleteTensor t) - let v = S.unsafeFromForeignPtr0 fp len + fp <- newForeignPtr_ bytes + -- Make an explicit copy of the raw data, since it might point + -- to a mutable variable's memory. + v <- S.freeze (M.unsafeFromForeignPtr0 fp len) + Raw.deleteTensor t return $ TensorData (map safeConvert dims) dtype v -- | Runs the given action which does FFI calls updating a provided