mirror of
https://github.com/tensorflow/haskell.git
synced 2025-01-24 17:55:02 +01:00
Fix TensorData encode and decode for Bool (#49)
This commit is contained in:
parent
cc08520dc7
commit
91f508eb5c
3 changed files with 18 additions and 6 deletions
|
@ -197,6 +197,7 @@ Test-Suite TypesTest
|
|||
, lens-family
|
||||
, google-shim
|
||||
, tensorflow
|
||||
, tensorflow-core-ops
|
||||
, tensorflow-ops
|
||||
, tensorflow-proto
|
||||
, transformers
|
||||
|
|
|
@ -37,6 +37,7 @@ import qualified Data.ByteString.Char8 as B8
|
|||
import qualified Data.Vector as V
|
||||
|
||||
import qualified TensorFlow.ControlFlow as TF
|
||||
import qualified TensorFlow.GenOps.Core as TF (select)
|
||||
import qualified TensorFlow.Ops as TF
|
||||
import qualified TensorFlow.Session as TF
|
||||
import qualified TensorFlow.Tensor as TF
|
||||
|
@ -52,17 +53,22 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $
|
|||
TF.runSession $ do
|
||||
let floatData = V.fromList [1..6 :: Float]
|
||||
stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
|
||||
boolData = V.fromList [True, True, False, True, False, False]
|
||||
f <- TF.build $ TF.placeholder [2,3]
|
||||
s <- TF.build $ TF.placeholder [2,3]
|
||||
b <- TF.build $ TF.placeholder [2,3]
|
||||
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
|
||||
, TF.feed s (TF.encodeTensorData [2,3] stringData)
|
||||
, TF.feed b (TF.encodeTensorData [2,3] boolData)
|
||||
]
|
||||
-- It is an error to fetch a tensor that is being fed, so the tensors
|
||||
-- are passed through identity.
|
||||
(f', s') <- TF.runWithFeeds feeds (TF.identity f, TF.identity s)
|
||||
-- Do something idempotent to the tensors to verify that tensorflow can
|
||||
-- handle the encoding. Originally this used `TF.identity`, but that
|
||||
-- wasn't enough to catch a bug in the encoding of Bool.
|
||||
(f', s', b') <- TF.runWithFeeds feeds (f+0, TF.identity s, TF.select b b b)
|
||||
liftIO $ do
|
||||
floatData @=? f'
|
||||
stringData @=? s'
|
||||
boolData @=? b'
|
||||
|
||||
|
||||
data TensorDataInputs a = TensorDataInputs [Int64] (V.Vector a)
|
||||
|
|
|
@ -241,13 +241,18 @@ instance TensorType ByteString where
|
|||
-- Convert to Vector Word8.
|
||||
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
|
||||
|
||||
|
||||
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
|
||||
-- encoding more expensive. It may make sense to define a custom boolean type.
|
||||
instance TensorType Bool where
|
||||
tensorType _ = DT_BOOL
|
||||
tensorRefType _ = DT_BOOL_REF
|
||||
tensorVal = boolVal
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
decodeTensorData =
|
||||
S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
|
||||
encodeTensorData (Shape xs) =
|
||||
TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert
|
||||
where
|
||||
fromBool x = if x then 1 else 0 :: Word8
|
||||
|
||||
instance TensorType (Complex Float) where
|
||||
tensorType _ = DT_COMPLEX64
|
||||
|
|
Loading…
Add table
Reference in a new issue