mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 11:03:34 +02:00
Reduce code duplication
This commit is contained in:
parent
d6942a6537
commit
e8387b50f3
|
@ -210,74 +210,34 @@ instance TensorDataType S.Vector Float where
|
|||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType V.Vector Float where
|
||||
decodeTensorData = S.convert . simpleDecode
|
||||
encodeTensorData x = simpleEncode x . S.convert
|
||||
|
||||
instance TensorDataType S.Vector Double where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType V.Vector Double where
|
||||
decodeTensorData = S.convert . simpleDecode
|
||||
encodeTensorData x = simpleEncode x . S.convert
|
||||
|
||||
instance TensorDataType S.Vector Int8 where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType V.Vector Int8 where
|
||||
decodeTensorData = S.convert . simpleDecode
|
||||
encodeTensorData x = simpleEncode x . S.convert
|
||||
|
||||
instance TensorDataType S.Vector Int16 where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType V.Vector Int16 where
|
||||
decodeTensorData = S.convert . simpleDecode
|
||||
encodeTensorData x = simpleEncode x . S.convert
|
||||
|
||||
instance TensorDataType S.Vector Int32 where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType V.Vector Int32 where
|
||||
decodeTensorData = S.convert . simpleDecode
|
||||
encodeTensorData x = simpleEncode x . S.convert
|
||||
|
||||
instance TensorDataType S.Vector Int64 where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType V.Vector Int64 where
|
||||
decodeTensorData = S.convert . simpleDecode
|
||||
encodeTensorData x = simpleEncode x . S.convert
|
||||
|
||||
instance TensorDataType S.Vector Word8 where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType V.Vector Word8 where
|
||||
decodeTensorData = S.convert . simpleDecode
|
||||
encodeTensorData x = simpleEncode x . S.convert
|
||||
|
||||
instance TensorDataType S.Vector Word16 where
|
||||
decodeTensorData = simpleDecode
|
||||
encodeTensorData = simpleEncode
|
||||
|
||||
instance TensorDataType V.Vector Word16 where
|
||||
decodeTensorData = S.convert . simpleDecode
|
||||
encodeTensorData x = simpleEncode x . S.convert
|
||||
|
||||
instance TensorDataType V.Vector (Complex Float) where
|
||||
decodeTensorData = error "TODO (Complex Float)"
|
||||
encodeTensorData = error "TODO (Complex Float)"
|
||||
|
||||
instance TensorDataType V.Vector (Complex Double) where
|
||||
decodeTensorData = error "TODO (Complex Double)"
|
||||
encodeTensorData = error "TODO (Complex Double)"
|
||||
|
||||
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
|
||||
-- encoding more expensive. It may make sense to define a custom boolean type.
|
||||
instance TensorDataType S.Vector Bool where
|
||||
|
@ -288,13 +248,20 @@ instance TensorDataType S.Vector Bool where
|
|||
where
|
||||
fromBool x = if x then 1 else 0 :: Word8
|
||||
|
||||
instance TensorDataType V.Vector Bool where
|
||||
decodeTensorData =
|
||||
(S.convert :: S.Vector Bool -> V.Vector Bool) . decodeTensorData
|
||||
encodeTensorData x =
|
||||
encodeTensorData x . (S.convert :: V.Vector Bool -> S.Vector Bool)
|
||||
instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a)
|
||||
=> TensorDataType V.Vector a where
|
||||
decodeTensorData = (S.convert :: S.Vector a -> V.Vector a) . decodeTensorData
|
||||
encodeTensorData x = encodeTensorData x . (S.convert :: V.Vector a -> S.Vector a)
|
||||
|
||||
instance TensorDataType V.Vector ByteString where
|
||||
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Float) where
|
||||
decodeTensorData = error "TODO (Complex Float)"
|
||||
encodeTensorData = error "TODO (Complex Float)"
|
||||
|
||||
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
|
||||
decodeTensorData = error "TODO (Complex Double)"
|
||||
encodeTensorData = error "TODO (Complex Double)"
|
||||
|
||||
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
|
||||
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
|
||||
-- table offsets for each element :: [Word64]
|
||||
-- at each element offset:
|
||||
|
|
Loading…
Reference in New Issue
Block a user