mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-27 05:19:45 +01:00
Work around #92 by always copying TensorData when fetching.
It would be better to avoid the copy when it's not necessary, but that will require more involved changes to the internal API. (For example, Fetchable might need to allow IO or ST actions.)
This commit is contained in:
parent
37e3c9b084
commit
a64af5076a
2 changed files with 23 additions and 4 deletions
|
@ -94,10 +94,22 @@ testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
|
||||||
$ p1 `TF.add` p2
|
$ p1 `TF.add` p2
|
||||||
liftIO $ result @=? TF.Scalar 5
|
liftIO $ result @=? TF.Scalar 5
|
||||||
|
|
||||||
|
-- | See https://github.com/tensorflow/haskell/issues/92.
|
||||||
|
-- Even though we're not explicitly evaluating `f0` until the end,
|
||||||
|
-- it should hold the earlier value of the variable.
|
||||||
|
testRereadRef :: Test
|
||||||
|
testRereadRef = testCase "testReRunAssign" $ TF.runSession $ do
|
||||||
|
w <- TF.initializedVariable 0
|
||||||
|
f0 <- TF.run w
|
||||||
|
TF.run_ =<< TF.assign w (TF.scalar (0.1 :: Float))
|
||||||
|
f1 <- TF.run w
|
||||||
|
liftIO $ (0.0, 0.1) @=? (TF.unScalar f0, TF.unScalar f1)
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = googleTest [ testSaveRestore
|
main = googleTest [ testSaveRestore
|
||||||
, testSize
|
, testSize
|
||||||
, testReducedShape
|
, testReducedShape
|
||||||
, testPlaceholderCse
|
, testPlaceholderCse
|
||||||
, testScalarFeedCse
|
, testScalarFeedCse
|
||||||
|
, testRereadRef
|
||||||
]
|
]
|
||||||
|
|
|
@ -42,7 +42,7 @@ import Data.Typeable (Typeable)
|
||||||
import Data.Word (Word8)
|
import Data.Word (Word8)
|
||||||
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
|
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
|
||||||
import Foreign.C.String (CString)
|
import Foreign.C.String (CString)
|
||||||
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr)
|
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
|
||||||
import Foreign.Marshal.Alloc (free)
|
import Foreign.Marshal.Alloc (free)
|
||||||
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
|
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
|
||||||
import System.IO.Unsafe (unsafePerformIO)
|
import System.IO.Unsafe (unsafePerformIO)
|
||||||
|
@ -51,7 +51,7 @@ import qualified Data.Text as T
|
||||||
import qualified Data.Text.Encoding as T
|
import qualified Data.Text.Encoding as T
|
||||||
import qualified Data.Text.Encoding.Error as T
|
import qualified Data.Text.Encoding.Error as T
|
||||||
import qualified Data.Vector.Storable as S
|
import qualified Data.Vector.Storable as S
|
||||||
import qualified Foreign.Concurrent as ForeignC
|
import qualified Data.Vector.Storable.Mutable as M
|
||||||
|
|
||||||
import Data.ProtoLens (Message, encodeMessage)
|
import Data.ProtoLens (Message, encodeMessage)
|
||||||
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
|
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
|
||||||
|
@ -193,6 +193,10 @@ tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x
|
||||||
-- | Create a TensorData from a Raw.Tensor.
|
-- | Create a TensorData from a Raw.Tensor.
|
||||||
--
|
--
|
||||||
-- Takes ownership of the Raw.Tensor.
|
-- Takes ownership of the Raw.Tensor.
|
||||||
|
-- TODO: Currently, it just makes a copy of the Tensor (and then deletes it),
|
||||||
|
-- since the raw pointer may refer to storage inside a mutable TensorFlow
|
||||||
|
-- variable. We should avoid that copy when it's not needed; for example,
|
||||||
|
-- by making TensorData wrap an IOVector, and changing the code that uses it.
|
||||||
createTensorData :: Raw.Tensor -> IO TensorData
|
createTensorData :: Raw.Tensor -> IO TensorData
|
||||||
createTensorData t = do
|
createTensorData t = do
|
||||||
-- Read dimensions.
|
-- Read dimensions.
|
||||||
|
@ -203,8 +207,11 @@ createTensorData t = do
|
||||||
-- Read data.
|
-- Read data.
|
||||||
len <- safeConvert <$> Raw.tensorByteSize t
|
len <- safeConvert <$> Raw.tensorByteSize t
|
||||||
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
|
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
|
||||||
fp <- ForeignC.newForeignPtr bytes (Raw.deleteTensor t)
|
fp <- newForeignPtr_ bytes
|
||||||
let v = S.unsafeFromForeignPtr0 fp len
|
-- Make an explicit copy of the raw data, since it might point
|
||||||
|
-- to a mutable variable's memory.
|
||||||
|
v <- S.freeze (M.unsafeFromForeignPtr0 fp len)
|
||||||
|
Raw.deleteTensor t
|
||||||
return $ TensorData (map safeConvert dims) dtype v
|
return $ TensorData (map safeConvert dims) dtype v
|
||||||
|
|
||||||
-- | Runs the given action which does FFI calls updating a provided
|
-- | Runs the given action which does FFI calls updating a provided
|
||||||
|
|
Loading…
Reference in a new issue