Fix TensorData encode and decode for Bool (#49)

This commit is contained in:
fkm3 2016-12-12 19:40:32 -08:00 committed by GitHub
parent cc08520dc7
commit 91f508eb5c
3 changed files with 18 additions and 6 deletions

View File

@ -197,6 +197,7 @@ Test-Suite TypesTest
, lens-family , lens-family
, google-shim , google-shim
, tensorflow , tensorflow
, tensorflow-core-ops
, tensorflow-ops , tensorflow-ops
, tensorflow-proto , tensorflow-proto
, transformers , transformers

View File

@ -37,6 +37,7 @@ import qualified Data.ByteString.Char8 as B8
import qualified Data.Vector as V import qualified Data.Vector as V
import qualified TensorFlow.ControlFlow as TF import qualified TensorFlow.ControlFlow as TF
import qualified TensorFlow.GenOps.Core as TF (select)
import qualified TensorFlow.Ops as TF import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Session as TF import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF import qualified TensorFlow.Tensor as TF
@ -52,17 +53,22 @@ testFFIRoundTrip = testCase "testFFIRoundTrip" $
TF.runSession $ do TF.runSession $ do
let floatData = V.fromList [1..6 :: Float] let floatData = V.fromList [1..6 :: Float]
stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]] stringData = V.fromList [B8.pack (show x) | x <- [1..6::Integer]]
boolData = V.fromList [True, True, False, True, False, False]
f <- TF.build $ TF.placeholder [2,3] f <- TF.build $ TF.placeholder [2,3]
s <- TF.build $ TF.placeholder [2,3] s <- TF.build $ TF.placeholder [2,3]
b <- TF.build $ TF.placeholder [2,3]
let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData) let feeds = [ TF.feed f (TF.encodeTensorData [2,3] floatData)
, TF.feed s (TF.encodeTensorData [2,3] stringData) , TF.feed s (TF.encodeTensorData [2,3] stringData)
, TF.feed b (TF.encodeTensorData [2,3] boolData)
] ]
-- It is an error to fetch a tensor that is being fed, so the tensors -- Do something idempotent to the tensors to verify that tensorflow can
-- are passed through identity. -- handle the encoding. Originally this used `TF.identity`, but that
(f', s') <- TF.runWithFeeds feeds (TF.identity f, TF.identity s) -- wasn't enough to catch a bug in the encoding of Bool.
(f', s', b') <- TF.runWithFeeds feeds (f+0, TF.identity s, TF.select b b b)
liftIO $ do liftIO $ do
floatData @=? f' floatData @=? f'
stringData @=? s' stringData @=? s'
boolData @=? b'
data TensorDataInputs a = TensorDataInputs [Int64] (V.Vector a) data TensorDataInputs a = TensorDataInputs [Int64] (V.Vector a)

View File

@ -241,13 +241,18 @@ instance TensorType ByteString where
-- Convert to Vector Word8. -- Convert to Vector Word8.
byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
-- encoding more expensive. It may make sense to define a custom boolean type.
instance TensorType Bool where instance TensorType Bool where
tensorType _ = DT_BOOL tensorType _ = DT_BOOL
tensorRefType _ = DT_BOOL_REF tensorRefType _ = DT_BOOL_REF
tensorVal = boolVal tensorVal = boolVal
decodeTensorData = simpleDecode decodeTensorData =
encodeTensorData = simpleEncode S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
encodeTensorData (Shape xs) =
TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert
where
fromBool x = if x then 1 else 0 :: Word8
instance TensorType (Complex Float) where instance TensorType (Complex Float) where
tensorType _ = DT_COMPLEX64 tensorType _ = DT_COMPLEX64