mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
parent
c99a23b6a7
commit
fdbfd050f8
2 changed files with 20 additions and 2 deletions
|
@ -203,9 +203,15 @@ matTranspose' params = flip (CoreOps.transpose' params) (vector [1, 0 :: Int32])
|
||||||
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
placeholder :: (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
|
||||||
placeholder = placeholder' id
|
placeholder = placeholder' id
|
||||||
|
|
||||||
placeholder' :: (MonadBuild m, TensorType a) => OpParams -> Shape -> m (Tensor Value a)
|
placeholder' :: forall m a . (MonadBuild m, TensorType a)
|
||||||
|
=> OpParams -> Shape -> m (Tensor Value a)
|
||||||
placeholder' params pShape
|
placeholder' params pShape
|
||||||
= render $ CoreOps.placeholder' (params . (opAttr "shape" .~ pShape))
|
-- Note: we don't use CoreOps.placeholder' since that op isn't stateful,
|
||||||
|
-- and thus would be CSE'd.
|
||||||
|
= build $ buildOp $ opDef "Placeholder"
|
||||||
|
& opAttr "dtype" .~ tensorType (undefined :: a)
|
||||||
|
& opAttr "shape" .~ pShape
|
||||||
|
& params
|
||||||
|
|
||||||
-- | Creates a variable initialized to the given value.
|
-- | Creates a variable initialized to the given value.
|
||||||
-- Initialization happens next time session runs.
|
-- Initialization happens next time session runs.
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
-- See the License for the specific language governing permissions and
|
-- See the License for the specific language governing permissions and
|
||||||
-- limitations under the License.
|
-- limitations under the License.
|
||||||
|
|
||||||
|
{-# LANGUAGE OverloadedLists #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
module Main where
|
module Main where
|
||||||
|
@ -68,9 +69,20 @@ testSaveRestore = testCase "testSaveRestore" $
|
||||||
TF.run v
|
TF.run v
|
||||||
liftIO $ TF.Scalar 134 @=? result
|
liftIO $ TF.Scalar 134 @=? result
|
||||||
|
|
||||||
|
-- | Test that 'placeholder' is not CSE'd.
|
||||||
|
testPlaceholderCse :: Test
|
||||||
|
testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
|
||||||
|
p1 <- TF.placeholder []
|
||||||
|
p2 <- TF.placeholder []
|
||||||
|
let enc :: Float -> TF.TensorData Float
|
||||||
|
enc n = TF.encodeTensorData [] (V.fromList [n])
|
||||||
|
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
||||||
|
liftIO $ result @=? TF.Scalar 5
|
||||||
|
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = googleTest [ testSaveRestore
|
main = googleTest [ testSaveRestore
|
||||||
, testSize
|
, testSize
|
||||||
, testReducedShape
|
, testReducedShape
|
||||||
|
, testPlaceholderCse
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue