mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-23 03:19:44 +01:00
Add another test of CSE and feeds. (#87)
As a follow-up to #86, check that our CSE isn't too aggressive to prevent feeds of pure ops with distinct names.
This commit is contained in:
parent
fdbfd050f8
commit
a11a417ad5
1 changed files with 14 additions and 0 deletions
|
@ -79,10 +79,24 @@ testPlaceholderCse = testCase "testPlaceholderCse" $ TF.runSession $ do
|
|||
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
||||
liftIO $ result @=? TF.Scalar 5
|
||||
|
||||
-- | Test that regular tensors can also be used for feeds, as long as they each
|
||||
-- have a different name.
|
||||
testScalarFeedCse :: Test
|
||||
testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
|
||||
p1 <- TF.render $ TF.scalar' (TF.opName .~ "A") 0
|
||||
-- The second op is identical to the first other than its name; make sure
|
||||
-- we don't aggressively CSE them together and prevent feeding them
|
||||
-- separately.
|
||||
p2 <- TF.render $ TF.scalar' (TF.opName .~ "B") 0
|
||||
let enc :: Float -> TF.TensorData Float
|
||||
enc n = TF.encodeTensorData [] (V.fromList [n])
|
||||
result <- TF.runWithFeeds [TF.feed p1 (enc 2), TF.feed p2 (enc 3)] $ p1 + p2
|
||||
liftIO $ result @=? TF.Scalar 5
|
||||
|
||||
main :: IO ()
|
||||
main = googleTest [ testSaveRestore
|
||||
, testSize
|
||||
, testReducedShape
|
||||
, testPlaceholderCse
|
||||
, testScalarFeedCse
|
||||
]
|
||||
|
|
Loading…
Reference in a new issue