-- | Element types that can be stored in tensors ('TensorType'), tensor
-- shapes, op attribute lenses, and type-level helpers for restricting a type
-- variable to a subset of the supported tensor element types.
module TensorFlow.Types
    ( -- * Tensor element types and raw data
      TensorType(..)
    , TensorData(..)
      -- * Shapes
    , Shape(..)
    , protoShape
      -- * Op attributes
    , Attribute(..)
      -- * Type-level constraints on element types
    , OneOf
    , type (/=)
    , TypeError
    , ExcludedCase
    , TensorTypes
    , NoneOf
    , type (\\)
    , Delete
    , AllTensorTypes
    ) where
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Monoid ((<>))
import Data.Word (Word8, Word16, Word64)
import Foreign.Storable (Storable)
import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~))
import Lens.Family2.Unchecked (iso)
import qualified Data.Attoparsec.ByteString as Atto
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Lazy as L
import qualified Data.Vector as V
import qualified Data.Vector.Storable as S
import Proto.Tensorflow.Core.Framework.AttrValue
( AttrValue(..)
, AttrValue'ListValue(..)
, b
, f
, i
, s
, list
, type'
, shape
, tensor
)
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
( TensorProto(..)
, floatVal
, doubleVal
, intVal
, stringVal
, int64Val
, stringVal
, boolVal
)
import Proto.Tensorflow.Core.Framework.TensorShape
( TensorShapeProto(..)
, dim
, size
)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
import qualified TensorFlow.Internal.FFI as FFI
-- | Raw tensor data as handled by the FFI layer. The type parameter @a@ is
-- phantom: it does not appear in the wrapped 'FFI.TensorData' and only
-- records, at the type level, which element type the data is meant to hold.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
-- | Haskell types that can be used as tensor elements.
class TensorType a where
    -- | The 'DataType' tag for @a@. The instances in this module ignore the
    -- argument, so callers may pass 'undefined'.
    tensorType :: a -> DataType
    -- | The @_REF@ variant of the tag (e.g. 'DT_FLOAT_REF' for 'Float').
    -- As with 'tensorType', the argument is only used to select the instance.
    tensorRefType :: a -> DataType
    -- | Lens onto the list of values of this type held in a 'TensorProto'.
    tensorVal :: Lens' TensorProto [a]
    -- | Decode raw tensor data into a vector of elements.
    decodeTensorData :: TensorData a -> V.Vector a
    -- | Encode a vector of elements as raw tensor data with the given shape.
    -- The instances in this module do not check that the number of elements
    -- agrees with the shape.
    encodeTensorData :: Shape -> V.Vector a -> TensorData a
-- | Decoder for element types whose in-memory ('Storable') layout is the
-- tensor's byte layout: the raw byte vector is reinterpreted as a storable
-- vector of @a@ and then converted to a boxed vector.
simpleDecode :: Storable a => TensorData a -> V.Vector a
simpleDecode (TensorData raw) = S.convert (S.unsafeCast rawBytes)
  where
    rawBytes = FFI.tensorDataBytes raw
-- | Encoder counterpart of 'simpleDecode': the boxed vector is converted to a
-- storable vector, reinterpreted as raw bytes, and packaged with the given
-- dimensions and the 'DataType' tag of @a@.
--
-- The element count is not checked against the shape.
simpleEncode :: forall a . (TensorType a, Storable a)
             => Shape -> V.Vector a -> TensorData a
simpleEncode (Shape dims) elems =
    TensorData (FFI.TensorData dims elemType (S.unsafeCast (S.convert elems)))
  where
    -- 'tensorType' only needs the type of its argument, never its value.
    elemType = tensorType (undefined :: a)
-- | 'Float' elements: tagged 'DT_FLOAT' / 'DT_FLOAT_REF', held in a
-- 'TensorProto' through 'floatVal', and converted to and from raw tensor
-- data with the 'Storable'-based codecs.
instance TensorType Float where
    tensorType = const DT_FLOAT
    tensorRefType = const DT_FLOAT_REF
    tensorVal = floatVal
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode
-- | 'Double' elements: tagged 'DT_DOUBLE' / 'DT_DOUBLE_REF', held in a
-- 'TensorProto' through 'doubleVal', and converted to and from raw tensor
-- data with the 'Storable'-based codecs.
instance TensorType Double where
    tensorType = const DT_DOUBLE
    tensorRefType = const DT_DOUBLE_REF
    tensorVal = doubleVal
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode
-- | 'Int32' elements: tagged 'DT_INT32' / 'DT_INT32_REF', held in a
-- 'TensorProto' through 'intVal', and converted to and from raw tensor data
-- with the 'Storable'-based codecs.
instance TensorType Int32 where
    tensorType = const DT_INT32
    tensorRefType = const DT_INT32_REF
    tensorVal = intVal
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode
-- | 'Int64' elements: tagged 'DT_INT64' / 'DT_INT64_REF', held in a
-- 'TensorProto' through 'int64Val', and converted to and from raw tensor
-- data with the 'Storable'-based codecs.
instance TensorType Int64 where
    tensorType = const DT_INT64
    tensorRefType = const DT_INT64_REF
    tensorVal = int64Val
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode
-- | View a list of 'Int32' as a list of any other 'Integral' type. Both
-- directions convert element-wise with 'fromIntegral'; no range check is
-- performed in either direction.
integral :: Integral a => Lens' [Int32] [a]
integral = iso fromInt32s toInt32s
  where
    fromInt32s ns = map fromIntegral ns
    toInt32s xs = map fromIntegral xs
-- | 'Word8' elements: tagged 'DT_UINT8' / 'DT_UINT8_REF'. In a
-- 'TensorProto' the values go through 'intVal' (a list of 'Int32'), with
-- 'integral' converting each element.
instance TensorType Word8 where
    tensorType = const DT_UINT8
    tensorRefType = const DT_UINT8_REF
    tensorVal = intVal . integral
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode
-- | 'Word16' elements: tagged 'DT_UINT16' / 'DT_UINT16_REF'. In a
-- 'TensorProto' the values go through 'intVal' (a list of 'Int32'), with
-- 'integral' converting each element.
instance TensorType Word16 where
    tensorType = const DT_UINT16
    tensorRefType = const DT_UINT16_REF
    tensorVal = intVal . integral
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode
-- | 'Int16' elements: tagged 'DT_INT16' / 'DT_INT16_REF'. In a
-- 'TensorProto' the values go through 'intVal' (a list of 'Int32'), with
-- 'integral' converting each element.
instance TensorType Int16 where
    tensorType = const DT_INT16
    tensorRefType = const DT_INT16_REF
    tensorVal = intVal . integral
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode
-- | 'Int8' elements: tagged 'DT_INT8' / 'DT_INT8_REF'. In a 'TensorProto'
-- the values go through 'intVal' (a list of 'Int32'), with 'integral'
-- converting each element.
instance TensorType Int8 where
    tensorType = const DT_INT8
    tensorRefType = const DT_INT8_REF
    tensorVal = intVal . integral
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode
-- | Strings do not use a flat array of fixed-size elements. The raw data is a
-- table holding one 8-byte offset per string, followed by the string data;
-- each string is stored as a varint length prefix followed by its bytes, and
-- each offset is relative to the start of the string data.
--
-- NOTE(review): encoding writes the offsets little-endian
-- ('Builder.word64LE') while decoding reinterprets the table with
-- 'S.unsafeCast'; these agree on little-endian hosts -- confirm whether
-- big-endian hosts need to be supported.
instance TensorType ByteString where
    tensorType _ = DT_STRING
    tensorRefType _ = DT_STRING_REF
    tensorVal = stringVal
    -- Calls 'error' when the data does not hold exactly one offset per
    -- element, when an offset points past the string data, or when a string
    -- fails to parse.
    decodeTensorData tensorData =
        either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
            if expected /= count
                then Left $ "decodeTensorData for ByteString count mismatch " ++
                            show (expected, count)
                else V.mapM decodeString (S.convert offsets)
      where
        -- Number of offsets actually available; smaller than 'count' when
        -- the data is too short to hold the whole table.
        expected = S.length offsets
        -- Number of elements implied by the tensor's dimensions.
        count = fromIntegral $ product $ FFI.tensorDataDimensions
                    $ unTensorData tensorData
        bytes = FFI.tensorDataBytes $ unTensorData tensorData
        offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64
        -- Everything after the offset table (8 bytes per entry).
        dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes
        dataLen = fromIntegral (B.length dataBytes) :: Word64
        decodeString :: Word64 -> Either String ByteString
        decodeString offset
            -- Compare as Word64: converting a huge offset to Int first would
            -- wrap negative, and 'B.drop' with a non-positive count returns
            -- the whole buffer, silently decoding the wrong string.
            | offset > dataLen =
                Left $ "string offset out of bounds " ++ show (offset, dataLen)
            | otherwise =
                let stringDataStart = B.drop (fromIntegral offset) dataBytes
                in Atto.eitherResult $ Atto.parse stringParser stringDataStart
        -- A varint length followed by that many bytes.
        stringParser :: Atto.Parser ByteString
        stringParser = getVarInt >>= Atto.take . fromIntegral
    encodeTensorData (Shape xs) vec =
        TensorData $ FFI.TensorData xs dt byteVector
      where
        dt = tensorType (undefined :: ByteString)
        -- Accumulator: (offset table so far, string data so far, offset at
        -- which the next string starts within the string data).
        addString :: (Builder, Builder, Word64)
                  -> ByteString
                  -> (Builder, Builder, Word64)
        addString (table, strings, offset) str =
            ( table <> Builder.word64LE offset
            , strings <> lengthBytes <> Builder.byteString str
            , offset + lengthBytesLen + strLen
            )
          where
            strLen = fromIntegral $ B.length str
            lengthBytes = putVarInt $ fromIntegral $ B.length str
            -- The varint is rendered here only to measure its encoded size.
            lengthBytesLen =
                fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes
        (table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec
        bytes = table' <> strings'
        byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
-- | 'Bool' elements: tagged 'DT_BOOL' / 'DT_BOOL_REF', held in a
-- 'TensorProto' through 'boolVal', and converted to and from raw tensor data
-- with the 'Storable'-based codecs.
instance TensorType Bool where
    tensorType = const DT_BOOL
    tensorRefType = const DT_BOOL_REF
    tensorVal = boolVal
    decodeTensorData = simpleDecode
    encodeTensorData = simpleEncode
-- | @'Complex' 'Float'@ elements, tagged 'DT_COMPLEX64'. Only the type tags
-- are implemented; the proto lens and both codecs call 'error' when used.
instance TensorType (Complex Float) where
    tensorType _ = DT_COMPLEX64
    -- The reference tag is the @_REF@ variant, as for every other instance.
    tensorRefType _ = DT_COMPLEX64_REF
    tensorVal = error "TODO (Complex Float)"
    decodeTensorData = error "TODO (Complex Float)"
    encodeTensorData = error "TODO (Complex Float)"
-- | @'Complex' 'Double'@ elements, tagged 'DT_COMPLEX128'. Only the type
-- tags are implemented; the proto lens and both codecs call 'error' when
-- used.
instance TensorType (Complex Double) where
    tensorType _ = DT_COMPLEX128
    -- The reference tag is the @_REF@ variant, as for every other instance.
    tensorRefType _ = DT_COMPLEX128_REF
    tensorVal = error "TODO (Complex Double)"
    decodeTensorData = error "TODO (Complex Double)"
    encodeTensorData = error "TODO (Complex Double)"
-- | The shape of a tensor, given as the list of its dimension sizes.
newtype Shape = Shape [Int64] deriving Show
-- | Lets a 'Shape' be converted to and from a plain list of 'Int64'
-- dimension sizes (and written as a list literal where overloaded lists are
-- enabled).
instance IsList Shape where
    type Item Shape = Int64
    fromList = Shape
    toList (Shape dims) = dims
-- | View a 'TensorShapeProto' as a 'Shape'. Reading collects the 'size' of
-- every entry in 'dim'; writing builds a default proto whose 'dim' entries
-- are default dimensions carrying only the given sizes, so any other data in
-- the original proto is not preserved.
protoShape :: Lens' TensorShapeProto Shape
protoShape = iso fromProto toProto
  where
    fromProto :: TensorShapeProto -> Shape
    fromProto proto = Shape [view size d | d <- view dim proto]

    toProto :: Shape -> TensorShapeProto
    toProto (Shape extents) =
        def & dim .~ [def & size .~ extent | extent <- extents]
-- | Types that can be stored in an 'AttrValue'.
class Attribute a where
    -- | Lens onto the part of an 'AttrValue' that holds a value of type @a@.
    attrLens :: Lens' AttrValue a
-- | Float attribute, via the 'f' lens.
instance Attribute Float where
    attrLens = f

-- | Byte-string attribute, via the 's' lens.
instance Attribute ByteString where
    attrLens = s

-- | Integer attribute, via the 'i' lens.
instance Attribute Int64 where
    attrLens = i

-- | Data-type attribute, via the 'type'' lens.
instance Attribute DataType where
    attrLens = type'

-- | Tensor attribute, via the 'tensor' lens.
instance Attribute TensorProto where
    attrLens = tensor

-- | Boolean attribute, via the 'b' lens.
instance Attribute Bool where
    attrLens = b

-- | Shape attribute: the 'shape' lens composed with 'protoShape', so callers
-- work with 'Shape' rather than the shape proto.
instance Attribute Shape where
    attrLens = shape . protoShape

-- | The whole list-valued part of an attribute, via the 'list' lens.
instance Attribute AttrValue'ListValue where
    attrLens = list

-- | List-of-data-types attribute: 'type'' applied inside the 'list' value.
instance Attribute [DataType] where
    attrLens = list . type'

-- | List-of-integers attribute: 'i' applied inside the 'list' value.
instance Attribute [Int64] where
    attrLens = list . i
-- | A constraint that @a@ is one of the types in the type-level list @ts@.
--
-- Membership is checked by exclusion: @a@ must have a 'TensorType' instance,
-- every element of @ts@ must have one too, and @a@ must differ from every
-- type in 'AllTensorTypes' that is not listed in @ts@.
type OneOf ts a
    = (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a)
-- | The constraint that every type in a type-level list has a 'TensorType'
-- instance. The empty list imposes no constraint.
type family TensorTypes ts :: Constraint where
    TensorTypes '[] = ()
    TensorTypes (x ': rest) = (TensorType x, TensorTypes rest)
-- | A constraint that two types are different.
--
-- The equations are tried in order. When both sides are the same type the
-- constraint becomes @'TypeError' x ~ 'ExcludedCase'@, an equality between
-- two distinct types that can never hold; otherwise it is the trivially
-- satisfied empty constraint.
type family a /= b :: Constraint where
    x /= x = TypeError x ~ ExcludedCase
    x /= y = ()
-- | Constructor-less type used only inside the unsatisfiable constraint
-- produced by '/='. Its parameter is the type that was excluded, presumably
-- so that it appears in the compiler's error message.
data TypeError a
-- | Constructor-less type used as the other side of the unsatisfiable
-- equality produced by '/=' when its two arguments are the same type.
data ExcludedCase
-- | The universe of element types that 'OneOf' excludes from: any type in
-- this list that is not named in a 'OneOf' list is rejected.
--
-- NOTE(review): @Complex Float@ and @Complex Double@ have 'TensorType'
-- instances in this module but are not listed here, so 'OneOf' never
-- excludes them -- confirm whether that is intentional.
type AllTensorTypes =
    '[ Float
     , Double
     , Int8
     , Int16
     , Int32
     , Int64
     , Word8
     , Word16
     , ByteString
     , Bool
     ]
-- | Remove every occurrence of a type from a type-level list.
--
-- The equations are tried in order, so the last one only applies when the
-- head of the list is known to differ from the type being removed.
type family Delete a as where
    Delete x '[] = '[]
    Delete x (x ': rest) = Delete x rest
    Delete x (y ': rest) = y ': Delete x rest
-- | Type-level list difference: the left-hand list with every occurrence of
-- each type in the right-hand list removed, using 'Delete' one type at a
-- time.
type family as \\ bs where
    xs \\ '[] = xs
    xs \\ (y ': ys) = Delete y xs \\ ys
-- | The constraint that a type differs (in the sense of '/=') from every
-- type in a type-level list. The empty list imposes no constraint.
type family NoneOf ts a :: Constraint where
    NoneOf '[] x = ()
    NoneOf (y ': ys) x = (x /= y, NoneOf ys x)