mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-30 06:49:44 +01:00
Add another test of CSE and feeds. (#87)
As a follow-up to #86, check that our CSE isn't too aggressive to prevent feeds of pure ops with distinct names.
This commit is contained in:
parent
fdbfd050f8
commit
a11a417ad5
1 changed files with 14 additions and 0 deletions
|
@ -79,10 +79,24 @@ testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
|
||||||
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
||||||
liftIO $ result @=? TF.Scalar 5
|
liftIO $ result @=? TF.Scalar 5
|
||||||
|
|
||||||
|
-- | Test that regular tensors can also be used for feeds, as long as they each
|
||||||
|
-- have a different name.
|
||||||
|
testScalarFeedCse :: Test
|
||||||
|
testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
|
||||||
|
p1 <- TF.render $ TF.scalar' (TF.opName .~ "A") 0
|
||||||
|
-- The second op is identical to the first other than its name; make sure
|
||||||
|
-- we don't aggressively CSE them together and prevent feeding them
|
||||||
|
-- separately.
|
||||||
|
p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0
|
||||||
|
let enc :: Float -> TF.TensorData Float
|
||||||
|
enc n = TF.encodeTensorData [] (V.fromList [n])
|
||||||
|
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
||||||
|
liftIO $ result @=? TF.Scalar 5
|
||||||
|
|
||||||
main :: IO ()
|
main :: IO ()
|
||||||
main = googleTest [ testSaveRestore
|
main = googleTest [ testSaveRestore
|
||||||
, testSize
|
, testSize
|
||||||
, testReducedShape
|
, testReducedShape
|
||||||
, testPlaceholderCse
|
, testPlaceholderCse
|
||||||
|
, testScalarFeedCse
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue