tensorflow-haskell/tensorflow/src/TensorFlow/Types.hs

618 lines
20 KiB
Haskell

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- We use UndecidableInstances for type families with recursive definitions
-- like "\\". Those instances will terminate since each equation unwraps one
-- cons cell of a type-level list.
{-# LANGUAGE UndecidableInstances #-}
module TensorFlow.Types
( TensorType(..)
, TensorData(..)
, TensorDataType(..)
, Scalar(..)
, Shape(..)
, protoShape
, Attribute(..)
, DataType(..)
, ResourceHandle
, Variant
-- * Lists
, ListOf(..)
, List
, (/:/)
, TensorTypeProxy(..)
, TensorTypes(..)
, TensorTypeList
, fromTensorTypeList
, fromTensorTypes
-- * Type constraints
, OneOf
, type (/=)
, OneOfs
-- ** Implementation of constraints
, TypeError
, ExcludedCase
, NoneOf
, type (\\)
, Delete
, AllTensorTypes
) where
import Data.ProtoLens.Message(defMessage)
import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Maybe (fromMaybe)
import Data.ProtoLens.TextFormat (showMessageShort)
import Data.Proxy (Proxy(..))
import Data.String (IsString)
import Data.Word (Word8, Word16, Word32, Word64)
import Foreign.Storable (Storable)
import GHC.Exts (Constraint, IsList(..))
import Lens.Family2 (Lens', view, (&), (.~), (^..), under)
import Lens.Family2.Unchecked (adapter)
import Text.Printf (printf)
import qualified Data.Attoparsec.ByteString as Atto
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Lazy as L
import qualified Data.Vector as V
import qualified Data.Vector.Storable as S
import Proto.Tensorflow.Core.Framework.AttrValue
( AttrValue
, AttrValue'ListValue
)
import Proto.Tensorflow.Core.Framework.AttrValue_Fields
( b
, f
, i
, s
, list
, type'
, shape
, tensor
)
import Proto.Tensorflow.Core.Framework.ResourceHandle
(ResourceHandleProto)
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
(TensorProto)
import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
( boolVal
, doubleVal
, floatVal
, intVal
, int64Val
, resourceHandleVal
, stringVal
, uint32Val
, uint64Val
)
import Proto.Tensorflow.Core.Framework.TensorShape
(TensorShapeProto)
import Proto.Tensorflow.Core.Framework.TensorShape_Fields
( dim
, size
, unknownRank
)
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
import qualified TensorFlow.Internal.FFI as FFI
-- | Handle to a TensorFlow resource; an alias for the generated
-- 'ResourceHandleProto' message type.
type ResourceHandle = ResourceHandleProto
-- | Dynamic type.
-- TensorFlow variants aren't supported yet. This type acts as a placeholder to
-- simplify op generation. It has no constructors.
data Variant
-- | The class of scalar types supported by tensorflow.
class TensorType a where
    -- | The 'DataType' tag corresponding to @a@. The instances in this module
    -- ignore the argument, so it is safe to pass @undefined :: a@.
    tensorType :: a -> DataType
    -- | The reference (@_REF@) 'DataType' tag corresponding to @a@. As with
    -- 'tensorType', the instances in this module ignore the argument.
    tensorRefType :: a -> DataType
    -- | Lens onto the 'TensorProto' field that holds values of type @a@.
    tensorVal :: Lens' TensorProto [a]
-- | 32-bit floats; values are accessed through 'floatVal'.
instance TensorType Float where
    tensorType = const DT_FLOAT
    tensorRefType = const DT_FLOAT_REF
    tensorVal = floatVal

-- | 64-bit floats; values are accessed through 'doubleVal'.
instance TensorType Double where
    tensorType = const DT_DOUBLE
    tensorRefType = const DT_DOUBLE_REF
    tensorVal = doubleVal

-- | 32-bit signed integers; values are accessed through 'intVal'.
instance TensorType Int32 where
    tensorType = const DT_INT32
    tensorRefType = const DT_INT32_REF
    tensorVal = intVal

-- | 64-bit signed integers; values are accessed through 'int64Val'.
instance TensorType Int64 where
    tensorType = const DT_INT64
    tensorRefType = const DT_INT64_REF
    tensorVal = int64Val
-- | View a list of 'Int32' as a list of another 'Integral' type. Elements are
-- converted with 'fromIntegral' in both directions.
integral :: Integral a => Lens' [Int32] [a]
integral = under (adapter fromInt32s toInt32s)
  where
    fromInt32s xs = map fromIntegral xs
    toInt32s ys = map fromIntegral ys
-- | 8-bit unsigned integers; values go through 'intVal' via 'integral'.
instance TensorType Word8 where
    tensorType = const DT_UINT8
    tensorRefType = const DT_UINT8_REF
    tensorVal = intVal . integral

-- | 16-bit unsigned integers; values go through 'intVal' via 'integral'.
instance TensorType Word16 where
    tensorType = const DT_UINT16
    tensorRefType = const DT_UINT16_REF
    tensorVal = intVal . integral

-- | 32-bit unsigned integers; values are accessed through 'uint32Val'.
instance TensorType Word32 where
    tensorType = const DT_UINT32
    tensorRefType = const DT_UINT32_REF
    tensorVal = uint32Val

-- | 64-bit unsigned integers; values are accessed through 'uint64Val'.
instance TensorType Word64 where
    tensorType = const DT_UINT64
    tensorRefType = const DT_UINT64_REF
    tensorVal = uint64Val
-- | 16-bit signed integers; values go through 'intVal' via 'integral'.
instance TensorType Int16 where
    tensorType = const DT_INT16
    tensorRefType = const DT_INT16_REF
    tensorVal = intVal . integral

-- | 8-bit signed integers; values go through 'intVal' via 'integral'.
instance TensorType Int8 where
    tensorType = const DT_INT8
    tensorRefType = const DT_INT8_REF
    tensorVal = intVal . integral
-- | Byte strings; values are accessed through 'stringVal'.
instance TensorType ByteString where
    tensorType = const DT_STRING
    tensorRefType = const DT_STRING_REF
    tensorVal = stringVal

-- | Booleans; values are accessed through 'boolVal'.
instance TensorType Bool where
    tensorType = const DT_BOOL
    tensorRefType = const DT_BOOL_REF
    tensorVal = boolVal
-- | Single-precision complex numbers.
--
-- 'tensorRefType' yields the @_REF@ tag, like every other instance in this
-- module. 'tensorVal' is not implemented yet and calls 'error' when used.
instance TensorType (Complex Float) where
    tensorType _ = DT_COMPLEX64
    tensorRefType _ = DT_COMPLEX64_REF
    tensorVal = error "TODO (Complex Float)"

-- | Double-precision complex numbers.
--
-- 'tensorRefType' yields the @_REF@ tag, like every other instance in this
-- module. 'tensorVal' is not implemented yet and calls 'error' when used.
instance TensorType (Complex Double) where
    tensorType _ = DT_COMPLEX128
    tensorRefType _ = DT_COMPLEX128_REF
    tensorVal = error "TODO (Complex Double)"
-- | Resource handles; values are accessed through 'resourceHandleVal'.
instance TensorType ResourceHandle where
    tensorType = const DT_RESOURCE
    tensorRefType = const DT_RESOURCE_REF
    tensorVal = resourceHandleVal

-- | Variants; 'tensorVal' is not implemented yet and calls 'error' when used.
instance TensorType Variant where
    tensorType = const DT_VARIANT
    tensorRefType = const DT_VARIANT_REF
    tensorVal = error "TODO Variant"
-- | Tensor data with the correct memory layout for tensorflow.
--
-- A wrapper around 'FFI.TensorData'; the type parameter @a@ is phantom and
-- records the element type.
newtype TensorData a = TensorData { unTensorData :: FFI.TensorData }
-- | Types that can be converted to and from 'TensorData'.
--
-- @s@ is the container type (for example 'S.Vector', 'V.Vector' or 'Scalar')
-- and @a@ is the element type.
--
-- 'S.Vector' is the most efficient to encode/decode for most element types.
class TensorType a => TensorDataType s a where
    -- | Decode the bytes of a 'TensorData' into an 's'.
    decodeTensorData :: TensorData a -> s a
    -- | Encode an 's' into a 'TensorData'.
    --
    -- The values should be in row major order, e.g.,
    --
    --   element 0:   index (0, ..., 0)
    --   element 1:   index (0, ..., 1)
    --   ...
    encodeTensorData :: Shape -> s a -> TensorData a
-- All types, besides ByteString and Bool, are encoded as simple arrays and we
-- can use Vector.Storable to encode/decode by type casting pointers.
-- TODO(fmayle): Assert that the data type matches the return type.

-- | Reinterpret the raw bytes of a 'TensorData' as a storable vector of @a@.
-- No check is made that the stored data type matches @a@.
simpleDecode :: Storable a => TensorData a -> S.Vector a
simpleDecode (TensorData raw) = S.unsafeCast (FFI.tensorDataBytes raw)
-- | Wrap a storable vector as 'TensorData' of the given shape by
-- reinterpreting its elements as bytes.
--
-- Calls 'error' when the product of the dimensions differs from the vector
-- length.
simpleEncode :: forall a . (TensorType a, Storable a)
             => Shape -> S.Vector a -> TensorData a
simpleEncode (Shape dims) vec
    | expected == fromIntegral actual =
        TensorData (FFI.TensorData dims elemType (S.unsafeCast vec))
    | otherwise = error $ printf
        "simpleEncode: bad vector length for shape %v: expected=%d got=%d"
        (show dims) expected actual
  where
    -- Number of elements the shape calls for.
    expected = product dims
    -- Number of elements actually supplied.
    actual = S.length vec
    elemType = tensorType (undefined :: a)
-- | 'Float' vectors are cast directly to and from the raw tensor bytes.
instance TensorDataType S.Vector Float where
    decodeTensorData td = simpleDecode td
    encodeTensorData sh vec = simpleEncode sh vec

-- | 'Double' vectors are cast directly to and from the raw tensor bytes.
instance TensorDataType S.Vector Double where
    decodeTensorData td = simpleDecode td
    encodeTensorData sh vec = simpleEncode sh vec
-- | 'Int8' vectors are cast directly to and from the raw tensor bytes.
instance TensorDataType S.Vector Int8 where
    decodeTensorData td = simpleDecode td
    encodeTensorData sh vec = simpleEncode sh vec

-- | 'Int16' vectors are cast directly to and from the raw tensor bytes.
instance TensorDataType S.Vector Int16 where
    decodeTensorData td = simpleDecode td
    encodeTensorData sh vec = simpleEncode sh vec

-- | 'Int32' vectors are cast directly to and from the raw tensor bytes.
instance TensorDataType S.Vector Int32 where
    decodeTensorData td = simpleDecode td
    encodeTensorData sh vec = simpleEncode sh vec

-- | 'Int64' vectors are cast directly to and from the raw tensor bytes.
instance TensorDataType S.Vector Int64 where
    decodeTensorData td = simpleDecode td
    encodeTensorData sh vec = simpleEncode sh vec
-- | 'Word8' vectors are cast directly to and from the raw tensor bytes.
instance TensorDataType S.Vector Word8 where
    decodeTensorData td = simpleDecode td
    encodeTensorData sh vec = simpleEncode sh vec

-- | 'Word16' vectors are cast directly to and from the raw tensor bytes.
instance TensorDataType S.Vector Word16 where
    decodeTensorData td = simpleDecode td
    encodeTensorData sh vec = simpleEncode sh vec
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
-- encoding more expensive. It may make sense to define a custom boolean type.

-- | Booleans are stored as one byte each: any non-zero byte decodes to 'True',
-- and 'True'/'False' encode to 1/0.
--
-- NOTE(review): unlike 'simpleEncode', encoding does not check that the vector
-- length matches the shape -- confirm whether callers rely on that.
instance TensorDataType S.Vector Bool where
    decodeTensorData (TensorData raw) =
        S.convert (S.map (/= 0) (FFI.tensorDataBytes raw))
    encodeTensorData (Shape dims) flags =
        TensorData (FFI.TensorData dims DT_BOOL (S.map toByte (S.convert flags)))
      where
        toByte :: Bool -> Word8
        toByte True = 1
        toByte False = 0
-- | Fallback for boxed vectors: go through the storable-vector instance for
-- the same element type and convert between the two vector kinds.
instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a, TensorType a)
    => TensorDataType V.Vector a where
    decodeTensorData td = S.convert (decodeTensorData td :: S.Vector a)
    encodeTensorData sh boxed = encodeTensorData sh (S.convert boxed :: S.Vector a)
-- | Not implemented yet: both methods call 'error'. The instance overrides the
-- overlappable 'V.Vector' instance for this element type.
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Float) where
    decodeTensorData = error "TODO (Complex Float)"
    encodeTensorData = error "TODO (Complex Float)"
-- | Not implemented yet: both methods call 'error'. The instance overrides the
-- overlappable 'V.Vector' instance for this element type.
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
    decodeTensorData = error "TODO (Complex Double)"
    encodeTensorData = error "TODO (Complex Double)"
-- | String tensors use a dedicated layout rather than a plain element array.
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
    -- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
    --   table offsets for each element :: [Word64]
    --   at each element offset:
    --     string length :: VarInt64
    --     string data   :: [Word8]
    --
    -- Decoding calls 'error' when the offset table is shorter than the element
    -- count implied by the dimensions, or when a string fails to parse.
    decodeTensorData tensorData =
        case decoded of
            Left err -> error $ "Malformed TF_STRING tensor; " ++ err
            Right strings -> strings
      where
        raw = unTensorData tensorData
        payload = FFI.tensorDataBytes raw
        -- Number of elements implied by the tensor's dimensions.
        numElems = fromIntegral (product (FFI.tensorDataDimensions raw))
        -- Leading offset table: one Word64 per element.
        offsetTable :: S.Vector Word64
        offsetTable = S.take numElems (S.unsafeCast payload)
        -- Everything after the offset table (8 bytes per table entry).
        blob = B.pack (S.toList (S.drop (numElems * 8) payload))
        decoded
            | S.length offsetTable /= numElems = Left $
                "decodeTensorData for ByteString count mismatch "
                    ++ show (S.length offsetTable, numElems)
            | otherwise = V.mapM stringAt (S.convert offsetTable)
        -- Parse the length-prefixed string that starts at the given offset
        -- into the data blob.
        stringAt :: Word64 -> Either String ByteString
        stringAt offset =
            Atto.eitherResult
                (Atto.parse lengthPrefixed (B.drop (fromIntegral offset) blob))
        lengthPrefixed :: Atto.Parser ByteString
        lengthPrefixed = do
            len <- getVarInt
            Atto.take (fromIntegral len)
    encodeTensorData (Shape dims) strs =
        TensorData (FFI.TensorData dims elemType packed)
      where
        elemType = tensorType (undefined :: ByteString)
        -- Fold step: append one string to the offset table and to the data
        -- blob, tracking the offset at which the next string will start.
        step :: (Builder, Builder, Word64)
             -> ByteString
             -> (Builder, Builder, Word64)
        step (offsets, blob, next) str =
            ( offsets <> Builder.word64LE next
            , blob <> prefix <> Builder.byteString str
            , next + prefixLen + payloadLen
            )
          where
            payloadLen = fromIntegral (B.length str)
            -- Varint-encoded length that precedes the string bytes.
            prefix = putVarInt (fromIntegral (B.length str))
            prefixLen =
                fromIntegral (L.length (Builder.toLazyByteString prefix))
        -- Encode all strings.
        (offsetsOut, blobOut, _) = V.foldl' step (mempty, mempty, 0) strs
        -- Offset table followed by the data, as a Vector Word8.
        packed = S.fromList
            (L.unpack (Builder.toLazyByteString (offsetsOut <> blobOut)))
-- | Wrapper for a single value. Its 'TensorDataType' instance encodes and
-- decodes the value as a one-element vector.
newtype Scalar a = Scalar {unScalar :: a}
    deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
              RealFrac, IsString)
-- | A 'Scalar' is encoded as a one-element boxed vector. Decoding calls
-- 'error' (via 'headFromSingleton') unless the tensor holds exactly one
-- element.
instance (TensorDataType V.Vector a, TensorType a) => TensorDataType Scalar a where
    decodeTensorData td = Scalar (headFromSingleton (decodeTensorData td))
    encodeTensorData sh (Scalar v) = encodeTensorData sh (V.singleton v)
-- | Extract the only element of a one-element vector; calls 'error' for any
-- other length.
headFromSingleton :: V.Vector a -> a
headFromSingleton v = case V.length v of
    1 -> V.head v
    n -> error $
        "Unable to extract singleton from tensor of length "
        ++ show n
-- | Shape (dimensions) of a tensor.
--
-- TensorFlow supports shapes of unknown rank, which are represented as
-- @Nothing :: Maybe Shape@ in Haskell.
newtype Shape = Shape [Int64] deriving Show

-- | A 'Shape' converts to and from its list of dimension sizes.
instance IsList Shape where
    type Item Shape = Int64
    fromList dims = Shape dims
    toList (Shape dims) = dims
-- | View a 'TensorShapeProto' as a 'Shape'.
--
-- Reading calls 'error' when the proto has unknown rank; use
-- 'protoMaybeShape' to handle that case.
protoShape :: Lens' TensorShapeProto Shape
protoShape = under (adapter fromProto toProto)
  where
    fromProto proto = case view protoMaybeShape proto of
        Just known -> known
        Nothing -> error $
            "Can't convert TensorShapeProto with unknown rank to Shape: "
                ++ showMessageShort proto
    toProto known = defMessage & protoMaybeShape .~ Just known
-- | View a 'TensorShapeProto' as a 'Shape' whose rank may be unknown;
-- 'Nothing' corresponds to the proto's @unknownRank@ flag being set.
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
protoMaybeShape = under (adapter decode encode)
  where
    decode :: TensorShapeProto -> Maybe Shape
    decode proto
        | view unknownRank proto = Nothing
        | otherwise = Just (Shape (proto ^.. dim . traverse . size))
    encode :: Maybe Shape -> TensorShapeProto
    encode Nothing = defMessage & unknownRank .~ True
    -- One dimension message per entry, carrying only its size.
    encode (Just (Shape ds)) =
        defMessage & dim .~ [defMessage & size .~ d | d <- ds]
-- | Types that can be stored in an 'AttrValue'.
class Attribute a where
    -- | Lens onto the part of an 'AttrValue' that holds a value of type @a@.
    attrLens :: Lens' AttrValue a
-- Single-valued attributes: each instance selects the matching 'AttrValue'
-- field lens.
instance Attribute Float where
    attrLens = f
instance Attribute ByteString where
    attrLens = s
instance Attribute Int64 where
    attrLens = i
instance Attribute DataType where
    attrLens = type'
instance Attribute TensorProto where
    attrLens = tensor
instance Attribute Bool where
    attrLens = b
-- A shape of known rank; goes through 'protoShape', which calls 'error' when
-- reading a proto of unknown rank.
instance Attribute Shape where
    attrLens = shape . protoShape
-- A shape whose rank may be unknown ('Nothing').
instance Attribute (Maybe Shape) where
    attrLens = shape . protoMaybeShape
-- TODO(gnezdo): support generating list(Foo) from [Foo].
-- List-valued attributes: the raw list message, and the two element types
-- that currently have dedicated instances.
instance Attribute AttrValue'ListValue where
    attrLens = list
instance Attribute [DataType] where
    attrLens = list . type'
instance Attribute [Int64] where
    attrLens = list . i
-- | A heterogeneous list type.
--
-- The type-level list @as@ records the type of each element, and every element
-- is stored wrapped in @f@.
data ListOf f as where
    -- | The empty list.
    Nil :: ListOf f '[]
    -- | Prepend an element to a list.
    (:/) :: f a -> ListOf f as -> ListOf f (a ': as)
infixr 5 :/
-- | The conjunction of the constraint @f@ applied to every type in the list.
type family All f as :: Constraint where
    All c '[] = ()
    All c (t ': rest) = (c t, All c rest)
-- | Apply the type constructor @f@ to every type in the list.
type family Map f as where
    Map g '[] = '[]
    Map g (t ': rest) = g t ': Map g rest
-- | Two lists are equal when their elements are pairwise equal. The context
-- requires an 'Eq' instance for every wrapped element type.
instance All Eq (Map f as) => Eq (ListOf f as) where
    Nil == Nil = True
    (x :/ xs) == (y :/ ys) = x == y && xs == ys
    -- Newer versions of GHC use the GADT to tell that the previous cases are
    -- exhaustive.
#if __GLASGOW_HASKELL__ < 800
    _ == _ = False
#endif
-- | Renders a list with the ':/' operator, e.g. @x :/ y :/ Nil@.
--
-- ':/' is declared @infixr 5@, so the output is parenthesised whenever the
-- surrounding precedence is above 5 (not only under function application).
-- The head is shown at precedence 6 and the tail at precedence 5, so
-- right-nested conses print without redundant parentheses.
instance All Show (Map f as) => Show (ListOf f as) where
    showsPrec _ Nil = showString "Nil"
    showsPrec d (x :/ xs) = showParen (d > consPrec)
                          $ showsPrec (consPrec + 1) x
                          . showString " :/ "
                          . showsPrec consPrec xs
      where
        -- Must match the @infixr 5 :/@ fixity declaration.
        consPrec = 5 :: Int
-- | A heterogeneous list of plain values: a 'ListOf' whose elements are
-- wrapped in 'Identity'.
type List = ListOf Identity
-- | Equivalent of ':/' for lists: wraps the new head in 'Identity' before
-- prepending it.
(/:/) :: a -> List as -> List (a ': as)
x /:/ xs = Identity x :/ xs
infixr 5 /:/
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
--
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
-- natural representation as a disjunction, i.e.,
--
-- @
--    a == Double || a == Float
-- @
--
-- into a conjunction like
--
-- @
--     a \/= Int32 && a \/= Int64 && a \/= ByteString && ...
-- @
--
-- using an enumeration of all the possible 'TensorType's.
type OneOf ts a
    -- Assert `TensorTypes' ts` to make error messages a little better.
    = (TensorType a, TensorTypes' ts, NoneOf (AllTensorTypes \\ ts) a)
-- | Like 'OneOf', but for a list of types: every type in @as@ must avoid all
-- the 'AllTensorTypes' that are not listed in @ts@.
type OneOfs ts as = (TensorTypes as, TensorTypes' ts,
                     NoneOfs (AllTensorTypes \\ ts) as)
-- | Applies 'NoneOf' with the same exclusion list to every type in @as@.
type family NoneOfs ts as :: Constraint where
    NoneOfs excluded '[] = ()
    NoneOfs excluded (t ': rest) = (NoneOf excluded t, NoneOfs excluded rest)
-- | A value-less witness that @a@ is a 'TensorType'. Pattern matching on the
-- constructor brings the 'TensorType' instance into scope.
data TensorTypeProxy a where
    TensorTypeProxy :: TensorType a => TensorTypeProxy a
-- | A heterogeneous list holding one 'TensorTypeProxy' per type.
type TensorTypeList = ListOf TensorTypeProxy
-- | The 'DataType' tag of each entry in a 'TensorTypeList', in order.
fromTensorTypeList :: TensorTypeList ts -> [DataType]
fromTensorTypeList Nil = []
fromTensorTypeList (proxy :/ rest) =
    proxyDataType proxy : fromTensorTypeList rest
  where
    -- Matching on the proxy provides the 'TensorType' instance for @t@.
    proxyDataType :: forall t . TensorTypeProxy t -> DataType
    proxyDataType TensorTypeProxy = tensorType (undefined :: t)
-- | The 'DataType' tags for a type-level list of tensor types, in order. The
-- proxy argument only selects the list; its value is ignored.
fromTensorTypes :: forall as . TensorTypes as => Proxy as -> [DataType]
fromTensorTypes _ = fromTensorTypeList proxies
  where
    proxies :: TensorTypeList as
    proxies = tensorTypes
-- | A constraint that the input is a list of 'TensorType's.
class TensorTypes (ts :: [*]) where
    -- | One 'TensorTypeProxy' for each type in @ts@, in order.
    tensorTypes :: TensorTypeList ts
-- The empty type list has no proxies.
instance TensorTypes '[] where
    tensorTypes = Nil
-- Prepend a proxy for the head type to the proxies of the tail.
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) where
    tensorTypes = TensorTypeProxy :/ tensorTypes
-- | A simpler version of the 'TensorTypes' class, that doesn't run
-- afoul of @-Wsimplifiable-class-constraints@.
--
-- In more detail: the constraint @OneOf '[Double, Float] a@ leads
-- to the constraint @TensorTypes' '[Double, Float]@, as a safety-check
-- to give better error messages.  However, if @TensorTypes'@ were a class,
-- then GHC 8.2.1 would complain with the above warning unless @NoMonoBinds@
-- were enabled.  So instead, we use a separate type family for this purpose.
-- For more details: https://ghc.haskell.org/trac/ghc/ticket/11948
type family TensorTypes' (ts :: [*]) :: Constraint where
    -- Specialize this type family when `ts` is a long list, to avoid deeply
    -- nested tuples of constraints.  Works around a bug in ghc-8.0:
    -- https://ghc.haskell.org/trac/ghc/ticket/12175
    TensorTypes' (a ': b ': c ': d ': rest)
        = ( TensorType a, TensorType b, TensorType c, TensorType d
          , TensorTypes' rest )
    TensorTypes' (a ': b ': c ': rest)
        = (TensorType a, TensorType b, TensorType c, TensorTypes' rest)
    TensorTypes' (a ': b ': rest)
        = (TensorType a, TensorType b, TensorTypes' rest)
    TensorTypes' (a ': rest) = (TensorType a, TensorTypes' rest)
    TensorTypes' '[] = ()
-- | A constraint checking that two types are different.
--
-- When both types are the same, the constraint reduces to the unsatisfiable
-- equality @TypeError a ~ ExcludedCase@; otherwise it is trivially satisfied.
type family a /= b :: Constraint where
    x /= x = TypeError x ~ ExcludedCase
    x /= y = ()

-- | Helper types to produce a reasonable type error message when the Constraint
-- "a /= a" fails.
-- TODO(judahjacobson): Use ghc-8's CustomTypeErrors for this.
data TypeError a
data ExcludedCase
-- | An enumeration of all valid 'TensorType's.
--
-- NOTE(review): 'Word32', 'Word64', 'ResourceHandle', 'Variant' and the
-- 'Complex' types have 'TensorType' instances in this module but are not
-- listed here. 'OneOf' only excludes types drawn from this list, so those
-- types are never excluded by it -- confirm that this is intended.
type AllTensorTypes =
    -- NOTE: This list should be kept in sync with
    -- TensorFlow.OpGen.dtTypeToHaskell.
    -- TODO: Add support for Complex Float/Double.
    '[ Float
     , Double
     , Int8
     , Int16
     , Int32
     , Int64
     , Word8
     , Word16
     , ByteString
     , Bool
     ]
-- | Removes every occurrence of a type from the given list of types.
type family Delete a as where
    Delete x '[] = '[]
    Delete x (x ': rest) = Delete x rest
    Delete x (y ': rest) = y ': Delete x rest
-- | Takes the difference of two lists of types: every type in the second list
-- is removed from the first with 'Delete'.
type family as \\ bs where
    xs \\ '[] = xs
    xs \\ (y ': ys) = Delete y xs \\ ys
-- | A constraint that the type @a@ doesn't appear in the type list @ts@.
-- Assumes that @a@ and each of the elements of @ts@ are 'TensorType's.
type family NoneOf ts a :: Constraint where
    -- Specialize this type family when `ts` is a long list, to avoid deeply
    -- nested tuples of constraints.  Works around a bug in ghc-8.0:
    -- https://ghc.haskell.org/trac/ghc/ticket/12175
    NoneOf (w ': x ': y ': z ': rest) t
        = (t /= w, t /= x, t /= y, t /= z, NoneOf rest t)
    NoneOf (w ': x ': y ': rest) t = (t /= w, t /= x, t /= y, NoneOf rest t)
    NoneOf (w ': x ': rest) t = (t /= w, t /= x, NoneOf rest t)
    NoneOf (w ': rest) t = (t /= w, NoneOf rest t)
    NoneOf '[] t = ()