1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 11:03:34 +02:00

Reduce code duplication

This commit is contained in:
fkm3 2016-12-14 00:24:11 -08:00
parent d6942a6537
commit e8387b50f3

View File

@ -210,74 +210,34 @@ instance TensorDataType S.Vector Float where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType V.Vector Float where
decodeTensorData = S.convert . simpleDecode
encodeTensorData x = simpleEncode x . S.convert
instance TensorDataType S.Vector Double where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType V.Vector Double where
decodeTensorData = S.convert . simpleDecode
encodeTensorData x = simpleEncode x . S.convert
instance TensorDataType S.Vector Int8 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType V.Vector Int8 where
decodeTensorData = S.convert . simpleDecode
encodeTensorData x = simpleEncode x . S.convert
instance TensorDataType S.Vector Int16 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType V.Vector Int16 where
decodeTensorData = S.convert . simpleDecode
encodeTensorData x = simpleEncode x . S.convert
instance TensorDataType S.Vector Int32 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType V.Vector Int32 where
decodeTensorData = S.convert . simpleDecode
encodeTensorData x = simpleEncode x . S.convert
instance TensorDataType S.Vector Int64 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType V.Vector Int64 where
decodeTensorData = S.convert . simpleDecode
encodeTensorData x = simpleEncode x . S.convert
instance TensorDataType S.Vector Word8 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType V.Vector Word8 where
decodeTensorData = S.convert . simpleDecode
encodeTensorData x = simpleEncode x . S.convert
instance TensorDataType S.Vector Word16 where
decodeTensorData = simpleDecode
encodeTensorData = simpleEncode
instance TensorDataType V.Vector Word16 where
decodeTensorData = S.convert . simpleDecode
encodeTensorData x = simpleEncode x . S.convert
instance TensorDataType V.Vector (Complex Float) where
decodeTensorData = error "TODO (Complex Float)"
encodeTensorData = error "TODO (Complex Float)"
instance TensorDataType V.Vector (Complex Double) where
decodeTensorData = error "TODO (Complex Double)"
encodeTensorData = error "TODO (Complex Double)"
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
-- encoding more expensive. It may make sense to define a custom boolean type.
instance TensorDataType S.Vector Bool where
@ -288,13 +248,20 @@ instance TensorDataType S.Vector Bool where
where
fromBool x = if x then 1 else 0 :: Word8
instance TensorDataType V.Vector Bool where
decodeTensorData =
(S.convert :: S.Vector Bool -> V.Vector Bool) . decodeTensorData
encodeTensorData x =
encodeTensorData x . (S.convert :: V.Vector Bool -> S.Vector Bool)
instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a)
=> TensorDataType V.Vector a where
decodeTensorData = (S.convert :: S.Vector a -> V.Vector a) . decodeTensorData
encodeTensorData x = encodeTensorData x . (S.convert :: V.Vector a -> S.Vector a)
instance TensorDataType V.Vector ByteString where
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Float) where
decodeTensorData = error "TODO (Complex Float)"
encodeTensorData = error "TODO (Complex Float)"
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
decodeTensorData = error "TODO (Complex Double)"
encodeTensorData = error "TODO (Complex Double)"
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
-- table offsets for each element :: [Word64]
-- at each element offset: