629 lines
21 KiB
Haskell
629 lines
21 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE ConstraintKinds #-}
|
|
{-# LANGUAGE CPP #-}
|
|
{-# LANGUAGE DataKinds #-}
|
|
{-# LANGUAGE FlexibleContexts #-}
|
|
{-# LANGUAGE FlexibleInstances #-}
|
|
{-# LANGUAGE GADTs #-}
|
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
|
{-# LANGUAGE MonoLocalBinds #-}
|
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE RankNTypes #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
{-# LANGUAGE TypeFamilies #-}
|
|
{-# LANGUAGE TypeOperators #-}
|
|
-- We use UndecidableInstances for type families with recursive definitions
|
|
-- like "\\". Those instances will terminate since each equation unwraps one
|
|
-- cons cell of a type-level list.
|
|
{-# LANGUAGE UndecidableInstances #-}
|
|
|
|
module TensorFlow.Types
|
|
( TensorType(..)
|
|
, TensorData(..)
|
|
, TensorDataType(..)
|
|
, Scalar(..)
|
|
, Shape(..)
|
|
, protoShape
|
|
, Attribute(..)
|
|
, DataType(..)
|
|
, ResourceHandle
|
|
, Variant
|
|
-- * Lists
|
|
, ListOf(..)
|
|
, List
|
|
, (/:/)
|
|
, TensorTypeProxy(..)
|
|
, TensorTypes(..)
|
|
, TensorTypeList
|
|
, fromTensorTypeList
|
|
, fromTensorTypes
|
|
-- * Type constraints
|
|
, OneOf
|
|
, type (/=)
|
|
, OneOfs
|
|
-- ** Implementation of constraints
|
|
, TypeError
|
|
, ExcludedCase
|
|
, NoneOf
|
|
, type (\\)
|
|
, Delete
|
|
, AllTensorTypes
|
|
) where
|
|
|
|
import Data.Bits (shiftL, (.|.))
|
|
import Data.ProtoLens.Message(defMessage)
|
|
import Data.Functor.Identity (Identity(..))
|
|
import Data.Complex (Complex)
|
|
import Data.Int (Int8, Int16, Int32, Int64)
|
|
import Data.Maybe (fromMaybe)
|
|
import Data.ProtoLens.TextFormat (showMessageShort)
|
|
import Data.Proxy (Proxy(..))
|
|
import Data.String (IsString)
|
|
import Data.Word (Word8, Word16, Word32, Word64)
|
|
import Foreign.Storable (Storable)
|
|
import GHC.Exts (Constraint, IsList(..))
|
|
import Lens.Family2 (Lens', view, (&), (.~), (^..), under)
|
|
import Lens.Family2.Unchecked (adapter)
|
|
import Text.Printf (printf)
|
|
import Data.ByteString (ByteString)
|
|
import qualified Data.ByteString as B
|
|
import Data.ByteString.Builder (Builder)
|
|
import qualified Data.ByteString.Builder as Builder
|
|
import qualified Data.ByteString.Lazy as L
|
|
import qualified Data.Vector as V
|
|
import qualified Data.Vector.Storable as S
|
|
import Data.Vector.Split (chunksOf)
|
|
import Proto.Tensorflow.Core.Framework.AttrValue
|
|
( AttrValue
|
|
, AttrValue'ListValue
|
|
)
|
|
import Proto.Tensorflow.Core.Framework.AttrValue_Fields
|
|
( b
|
|
, f
|
|
, i
|
|
, s
|
|
, list
|
|
, type'
|
|
, shape
|
|
, tensor
|
|
)
|
|
|
|
import Proto.Tensorflow.Core.Framework.ResourceHandle
|
|
(ResourceHandleProto)
|
|
import Proto.Tensorflow.Core.Framework.Tensor as Tensor
|
|
(TensorProto)
|
|
import Proto.Tensorflow.Core.Framework.Tensor_Fields as Tensor
|
|
( boolVal
|
|
, doubleVal
|
|
, floatVal
|
|
, intVal
|
|
, int64Val
|
|
, resourceHandleVal
|
|
, stringVal
|
|
, uint32Val
|
|
, uint64Val
|
|
)
|
|
|
|
import Proto.Tensorflow.Core.Framework.TensorShape
|
|
(TensorShapeProto)
|
|
import Proto.Tensorflow.Core.Framework.TensorShape_Fields
|
|
( dim
|
|
, size
|
|
, unknownRank
|
|
)
|
|
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
|
|
|
|
import qualified TensorFlow.Internal.Raw as Raw
|
|
import qualified TensorFlow.Internal.FFI as FFI
|
|
|
|
-- | Element type of 'DT_RESOURCE' tensors: the protobuf message describing a
-- resource handle.
type ResourceHandle = ResourceHandleProto

-- | Dynamic type.
--
-- TensorFlow variants aren't supported yet. This type acts as a placeholder
-- to simplify op generation; it deliberately has no constructors.
data Variant
|
|
|
|
-- | The class of scalar types supported by tensorflow.
--
-- An instance links a Haskell element type to its TensorFlow 'DataType' tags
-- and to the repeated 'TensorProto' field that carries its values.
class TensorType a where
    -- | Data type tag of a tensor whose elements have type @a@.
    -- Only the argument's type matters; the instances in this module ignore
    -- its value.
    tensorType :: a -> DataType
    -- | Data type tag of a reference tensor whose elements have type @a@.
    tensorRefType :: a -> DataType
    -- | Lens onto the 'TensorProto' field holding the values as a list.
    tensorVal :: Lens' TensorProto [a]
|
|
|
|
-- Element types that have their own dedicated value field in 'TensorProto'.
instance TensorType Float where
    tensorType = const DT_FLOAT
    tensorRefType = const DT_FLOAT_REF
    tensorVal = floatVal

instance TensorType Double where
    tensorType = const DT_DOUBLE
    tensorRefType = const DT_DOUBLE_REF
    tensorVal = doubleVal

instance TensorType Int32 where
    tensorType = const DT_INT32
    tensorRefType = const DT_INT32_REF
    tensorVal = intVal

instance TensorType Int64 where
    tensorType = const DT_INT64
    tensorRefType = const DT_INT64_REF
    tensorVal = int64Val
|
|
|
|
-- | Views a list of 'Int32's as a list of another 'Integral' type, converting
-- every element with 'fromIntegral' in each direction.
--
-- Conversions are not range-checked: values that do not fit the target type
-- wrap around.
integral :: Integral a => Lens' [Int32] [a]
integral = under (adapter (map fromIntegral) (map fromIntegral))
|
|
|
|
-- 8- and 16-bit integer types have no field of their own: their values are
-- kept in the 'Int32' list behind 'intVal' and converted through 'integral'.
instance TensorType Word8 where
    tensorType = const DT_UINT8
    tensorRefType = const DT_UINT8_REF
    tensorVal = intVal . integral

instance TensorType Word16 where
    tensorType = const DT_UINT16
    tensorRefType = const DT_UINT16_REF
    tensorVal = intVal . integral

-- The wide unsigned types use their dedicated fields.
instance TensorType Word32 where
    tensorType = const DT_UINT32
    tensorRefType = const DT_UINT32_REF
    tensorVal = uint32Val

instance TensorType Word64 where
    tensorType = const DT_UINT64
    tensorRefType = const DT_UINT64_REF
    tensorVal = uint64Val

instance TensorType Int16 where
    tensorType = const DT_INT16
    tensorRefType = const DT_INT16_REF
    tensorVal = intVal . integral

instance TensorType Int8 where
    tensorType = const DT_INT8
    tensorRefType = const DT_INT8_REF
    tensorVal = intVal . integral

instance TensorType ByteString where
    tensorType = const DT_STRING
    tensorRefType = const DT_STRING_REF
    tensorVal = stringVal

instance TensorType Bool where
    tensorType = const DT_BOOL
    tensorRefType = const DT_BOOL_REF
    tensorVal = boolVal
|
|
|
|
-- Complex element types: the data type tags are known, but reading or
-- writing their values through a 'TensorProto' is not implemented yet.
instance TensorType (Complex Float) where
    tensorType _ = DT_COMPLEX64
    -- Reference tensors get the dedicated *_REF tag, as for every other
    -- element type (previously this returned the plain DT_COMPLEX64 tag).
    tensorRefType _ = DT_COMPLEX64_REF
    tensorVal = error "TODO (Complex Float)"

instance TensorType (Complex Double) where
    tensorType _ = DT_COMPLEX128
    tensorRefType _ = DT_COMPLEX128_REF
    tensorVal = error "TODO (Complex Double)"
|
|
|
|
-- Resource handles are read and written through the 'resourceHandleVal'
-- field lens.
instance TensorType ResourceHandle where
    tensorType = const DT_RESOURCE
    tensorRefType = const DT_RESOURCE_REF
    tensorVal = resourceHandleVal

-- Variants only provide their data type tags; accessing their values through
-- a 'TensorProto' is not supported.
instance TensorType Variant where
    tensorType = const DT_VARIANT
    tensorRefType = const DT_VARIANT_REF
    tensorVal = error "TODO Variant"
|
|
|
|
-- | Tensor data with the correct memory layout for tensorflow.
--
-- The type parameter @a@ is a phantom recording the element type. The wrapped
-- 'FFI.TensorData' holds the dimensions, the 'DataType' tag and the raw bytes.
newtype TensorData a = TensorData
    { unTensorData :: FFI.TensorData
    }
|
|
|
|
-- | Containers @s@ whose elements of type @a@ can be converted to and from
-- 'TensorData'.
--
-- 'S.Vector' is the most efficient to encode/decode for most element types.
class TensorType a => TensorDataType s a where
    -- | Decode the bytes of a 'TensorData' into a container @s a@.
    decodeTensorData :: TensorData a -> s a
    -- | Encode a container into a 'TensorData' with the given 'Shape'.
    --
    -- Elements must be supplied in row major order, i.e. the last index
    -- varies fastest:
    --
    -- element 0: index (0, ..., 0)
    --
    -- element 1: index (0, ..., 1)
    --
    -- ...
    encodeTensorData :: Shape -> s a -> TensorData a
|
|
|
|
-- All types, besides ByteString and Bool, are encoded as simple arrays and we
-- can use Vector.Storable to encode/decode by type casting pointers.

-- | Reinterprets the raw bytes of a 'TensorData' as a storable vector of its
-- element type.
--
-- TODO(fmayle): Assert that the data type matches the return type.
simpleDecode :: Storable a => TensorData a -> S.Vector a
simpleDecode (TensorData raw) = S.unsafeCast (FFI.tensorDataBytes raw)

-- | Wraps a storable vector's memory as 'TensorData' tagged with the element
-- type's 'DataType'.
--
-- Calls 'error' when the vector length differs from the element count implied
-- by the shape.
simpleEncode :: forall a . (TensorType a, Storable a)
             => Shape -> S.Vector a -> TensorData a
simpleEncode (Shape dims) vec
    | expected == fromIntegral (S.length vec) =
        TensorData (FFI.TensorData dims dataType (S.unsafeCast vec))
    | otherwise = error $ printf
        "simpleEncode: bad vector length for shape %v: expected=%d got=%d"
        (show dims) expected (S.length vec)
  where
    -- Number of elements the shape describes.
    expected = product dims
    -- 'tensorType' only looks at the type of its argument.
    dataType = tensorType (undefined :: a)
|
|
|
|
-- Fixed-width numeric element types are encoded and decoded by reinterpreting
-- the storable vector's memory directly ('simpleEncode' / 'simpleDecode').
instance TensorDataType S.Vector Float where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode

instance TensorDataType S.Vector Double where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode

instance TensorDataType S.Vector Int8 where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode

instance TensorDataType S.Vector Int16 where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode

instance TensorDataType S.Vector Int32 where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode

instance TensorDataType S.Vector Int64 where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode

instance TensorDataType S.Vector Word8 where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode

instance TensorDataType S.Vector Word16 where
    encodeTensorData = simpleEncode
    decodeTensorData = simpleDecode
|
|
|
|
-- TODO: Haskell and tensorflow use different byte sizes for bools, which makes
-- encoding more expensive. It may make sense to define a custom boolean type.
--
-- The tensor buffer holds one byte per element: any non-zero byte decodes to
-- 'True', and 'True'/'False' encode as 1/0.
instance TensorDataType S.Vector Bool where
    decodeTensorData =
        S.convert . S.map (/= 0) . FFI.tensorDataBytes . unTensorData
    -- Like 'simpleEncode', reject a vector whose length differs from the
    -- element count implied by the shape, instead of building a buffer that
    -- disagrees with its own dimensions.
    encodeTensorData (Shape xs) v
        | product xs /= fromIntegral (S.length v) = error $ printf
            "encodeTensorData (Bool): bad vector length for shape %v: expected=%d got=%d"
            (show xs) (product xs) (S.length v)
        | otherwise =
            TensorData . FFI.TensorData xs DT_BOOL . S.map fromBool . S.convert $ v
      where
        fromBool x = if x then 1 else 0 :: Word8
|
|
|
|
-- Boxed vectors reuse the storable-vector encoding of their element type,
-- converting between 'V.Vector' and 'S.Vector' around the call.
instance {-# OVERLAPPABLE #-} (Storable a, TensorDataType S.Vector a, TensorType a)
        => TensorDataType V.Vector a where
    decodeTensorData td = S.convert (decodeTensorData td :: S.Vector a)
    encodeTensorData shape vec = encodeTensorData shape (S.convert vec :: S.Vector a)
|
|
|
|
-- Complex element types have no 'TensorData' encoding yet: both methods of
-- these instances are 'error' calls.
instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Float) where
    encodeTensorData = error "TODO (Complex Float)"
    decodeTensorData = error "TODO (Complex Float)"

instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
    encodeTensorData = error "TODO (Complex Double)"
    decodeTensorData = error "TODO (Complex Double)"
|
|
|
|
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
    -- Strings can be encoded in various ways, see [0] for an overview.
    --
    -- The data starts with an array of TF_TString structs
    -- ('Raw.sizeOfTString' bytes each, 24 at the time of writing), one for
    -- each element in the tensor. In some cases, the actual string contents
    -- are inlined in the TF_TString, in some cases they are in the heap, in
    -- some cases they are appended to the end of the data.
    --
    -- When decoding, we delegate most of those details to the TString C API.
    -- However, when encoding, the TString C API is prone to memory leaks given
    -- the current design of tensorflow-haskell, so, instead we manually encode
    -- all the strings in the "offset" format, where none of the string data is
    -- stored in separate heap objects and so no destructor hook is necessary.
    --
    -- [0] https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
    decodeTensorData tensorData =
        if S.length bytes < minBytes
            then error $ "Malformed TF_STRING tensor; decodeTensorData for ByteString with too few bytes, got " ++
                         show (S.length bytes) ++ ", need at least " ++ show minBytes
            -- Split the table with the same struct size the bounds check
            -- above uses, rather than a separately hard-coded constant.
            else V.fromList $ map FFI.unsafeTStringToByteString
                            $ take numElements (chunksOf Raw.sizeOfTString bytes)
      where
        bytes = FFI.tensorDataBytes $ unTensorData tensorData
        numElements = fromIntegral $ product $ FFI.tensorDataDimensions $ unTensorData tensorData
        minBytes = Raw.sizeOfTString * numElements
    encodeTensorData (Shape xs) vec
        -- Like 'simpleEncode', reject a vector whose length differs from the
        -- element count implied by the shape; otherwise the offset table built
        -- below would not match the tensor's dimensions.
        | product xs /= fromIntegral (V.length vec) = error $ printf
            "encodeTensorData (ByteString): bad vector length for shape %v: expected=%d got=%d"
            (show xs) (product xs) (V.length vec)
        | otherwise = TensorData $ FFI.TensorData xs dt byteVector
      where
        dt = tensorType (undefined :: ByteString)
        -- Total size of the TF_TString table that precedes the string data.
        tableSize = fromIntegral $ Raw.sizeOfTString * (V.length vec)
        -- Add a string to an offset table and data blob.
        addString :: (Builder, Builder, Word32, Word32)
                  -> ByteString
                  -> (Builder, Builder, Word32, Word32)
        addString (table, strings, tableOffset, stringsOffset) str =
            ( table <> Builder.word32LE sizeField
                    <> Builder.word32LE offsetField
                    <> Builder.word32LE capacityField
                    <> Builder.word32LE 0
                    <> Builder.word32LE 0
                    <> Builder.word32LE 0
            , strings <> Builder.byteString str
            , tableOffset + fromIntegral Raw.sizeOfTString
            , stringsOffset + strLen
            )
          where
            strLen :: Word32 = fromIntegral $ B.length str
            -- TF_TString.size includes a union tag in the first two bits.
            sizeField :: Word32 = (shiftL strLen 2) .|. Raw.tstringOffsetTypeTag
            -- offset is relative to the start of the TF_TString instance, so
            -- we add the remaining distance to the end of the table to the
            -- offset from the start of the string data.
            offsetField :: Word32 = tableSize - tableOffset + stringsOffset
            capacityField :: Word32 = strLen
        -- Encode all strings.
        (table', strings', _, _) = V.foldl' addString (mempty, mempty, 0, 0) vec
        -- Concat offset table with data.
        bytes = table' <> strings'
        -- Convert to Vector Word8.
        byteVector = S.fromList $ L.unpack $ Builder.toLazyByteString bytes
|
|
|
|
-- | A single value, transferred as a one-element tensor.
--
-- The numeric and string classes are derived from the wrapped type, so
-- literals and arithmetic work on 'Scalar' directly.
newtype Scalar a = Scalar {unScalar :: a}
    deriving (Show, Eq, Ord, Num, Fractional, Floating, Real, RealFloat,
              RealFrac, IsString)

-- Scalars go through the boxed-vector encoding with exactly one element;
-- decoding fails unless the tensor holds exactly one value.
instance (TensorDataType V.Vector a, TensorType a) => TensorDataType Scalar a where
    decodeTensorData td = Scalar (headFromSingleton (decodeTensorData td))
    encodeTensorData shape = encodeTensorData shape . V.singleton . unScalar
|
|
|
|
-- | Returns the only element of a one-element vector; any other length is
-- reported with 'error'.
headFromSingleton :: V.Vector a -> a
headFromSingleton vec = case V.length vec of
    1 -> V.head vec
    n -> error $ "Unable to extract singleton from tensor of length " ++ show n
|
|
|
|
|
|
-- | Shape (dimensions) of a tensor.
--
-- TensorFlow supports shapes of unknown rank, which are represented as
-- @Nothing :: Maybe Shape@ in Haskell.
newtype Shape = Shape [Int64] deriving Show

-- | A 'Shape' is just its list of dimensions, so it can be written as a list
-- literal (with @OverloadedLists@) and converted back losslessly.
instance IsList Shape where
    type Item Shape = Int64
    fromList = Shape
    toList (Shape dims) = dims
|
|
|
|
-- | Lens between a 'TensorShapeProto' and a 'Shape' of known rank.
--
-- Reading a proto whose @unknown_rank@ flag is set calls 'error'.
protoShape :: Lens' TensorShapeProto Shape
protoShape = under (adapter fromProto toProto)
  where
    fromProto p = fromMaybe (unknownRankError p) (view protoMaybeShape p)
    unknownRankError p = error $
        "Can't convert TensorShapeProto with unknown rank to Shape: "
        ++ showMessageShort p
    toProto knownShape = defMessage & protoMaybeShape .~ Just knownShape

-- | Lens between a 'TensorShapeProto' and a shape that may have unknown rank.
--
-- 'Nothing' corresponds to a proto with @unknown_rank@ set; @Just (Shape ds)@
-- corresponds to a proto whose dimensions have the sizes @ds@, in order.
protoMaybeShape :: Lens' TensorShapeProto (Maybe Shape)
protoMaybeShape = under (adapter fromProto toProto)
  where
    fromProto :: TensorShapeProto -> Maybe Shape
    fromProto p
        | view unknownRank p = Nothing
        | otherwise = Just $ Shape $ p ^.. dim . traverse . size

    toProto :: Maybe Shape -> TensorShapeProto
    toProto mshape = case mshape of
        Nothing -> defMessage & unknownRank .~ True
        Just (Shape ds) -> defMessage & dim .~ [defMessage & size .~ d | d <- ds]
|
|
|
|
|
|
-- | Values that can be stored in an 'AttrValue' message, each exposed through
-- a lens onto the matching field.
class Attribute a where
    attrLens :: Lens' AttrValue a

-- Scalar attributes map directly onto a single field of 'AttrValue'.
instance Attribute Float where
    attrLens = f

instance Attribute ByteString where
    attrLens = s

instance Attribute Int64 where
    attrLens = i

instance Attribute DataType where
    attrLens = type'

instance Attribute TensorProto where
    attrLens = tensor

instance Attribute Bool where
    attrLens = b

-- Shapes go through the shape field; the plain 'Shape' view errors when the
-- stored shape has unknown rank, the 'Maybe' view represents it as 'Nothing'.
instance Attribute Shape where
    attrLens = shape . protoShape

instance Attribute (Maybe Shape) where
    attrLens = shape . protoMaybeShape

-- TODO(gnezdo): support generating list(Foo) from [Foo].
instance Attribute AttrValue'ListValue where
    attrLens = list

-- List attributes read the matching repeated field of the list value.
instance Attribute [DataType] where
    attrLens = list . type'

instance Attribute [Int64] where
    attrLens = list . i
|
|
|
|
-- | A heterogeneous list type.
--
-- A value of type @ListOf f '[a, b]@ holds an @f a@ followed by an @f b@.
data ListOf f as where
    Nil :: ListOf f '[]
    (:/) :: f a -> ListOf f as -> ListOf f (a ': as)

infixr 5 :/

-- | The constraint @f@ holds for every element of the type list @as@.
type family All f as :: Constraint where
    All f '[] = ()
    All f (a ': as) = (f a, All f as)

-- | Applies @f@ to every element of the type list @as@.
type family Map f as where
    Map f '[] = '[]
    Map f (a ': as) = f a ': Map f as

instance All Eq (Map f as) => Eq (ListOf f as) where
    Nil == Nil = True
    (x :/ xs) == (y :/ ys) = x == y && xs == ys
    -- Newer versions of GHC use the GADT to tell that the previous cases are
    -- exhaustive.
#if __GLASGOW_HASKELL__ < 800
    _ == _ = False
#endif

-- Shows the list with ':/' as an infix constructor, e.g.
-- @Identity 1 :/ Identity 2 :/ Nil@. Since ':/' is @infixr 5@, parentheses
-- are needed whenever the surrounding precedence exceeds 5 (previously only
-- above 10, which produced unparseable output at precedences 6-10). The tail
-- is shown at precedence 5 because right-nested chains need no parentheses.
instance All Show (Map f as) => Show (ListOf f as) where
    showsPrec _ Nil = showString "Nil"
    showsPrec d (x :/ xs) = showParen (d > 5)
        $ showsPrec 6 x . showString " :/ " . showsPrec 5 xs
|
|
|
|
-- | A heterogeneous list of plain values.
type List = ListOf Identity

-- | Equivalent of ':/' for lists: wraps the new head in 'Identity'.
(/:/) :: a -> List as -> List (a ': as)
x /:/ xs = Identity x :/ xs

infixr 5 /:/
|
|
|
|
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
--
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
-- natural representation as a disjunction, i.e.,
--
-- @
--     a == Double || a == Float
-- @
--
-- into a conjunction like
--
-- @
--     a \/= Int32 && a \/= Int64 && a \/= ByteString && ...
-- @
--
-- using an enumeration of all the possible 'TensorType's. A conjunction is
-- exactly what a tuple of constraints expresses, so the second form can be
-- written as an ordinary constraint.
type OneOf ts a =
    ( TensorType a
      -- Assert `TensorTypes' ts` to make error messages a little better.
    , TensorTypes' ts
    , NoneOf (AllTensorTypes \\ ts) a
    )
|
|
|
|
-- | Like 'OneOf', but for a whole list of types: every element of @as@ must be
-- one of the types in @ts@.
type OneOfs ts as =
    ( TensorTypes as
    , TensorTypes' ts
    , NoneOfs (AllTensorTypes \\ ts) as
    )

-- | Holds when no element of the type list @as@ appears in @excluded@.
type family NoneOfs excluded as :: Constraint where
    NoneOfs excluded '[] = ()
    NoneOfs excluded (x ': rest) = (NoneOf excluded x, NoneOfs excluded rest)
|
|
|
|
-- | A proxy whose constructor packs the 'TensorType' instance of its type
-- parameter.
data TensorTypeProxy a where
    TensorTypeProxy :: TensorType a => TensorTypeProxy a

-- | A type-indexed list with one 'TensorTypeProxy' per element type.
type TensorTypeList = ListOf TensorTypeProxy

-- | The 'DataType' of every element type in the list, in order.
fromTensorTypeList :: TensorTypeList ts -> [DataType]
fromTensorTypeList Nil = []
fromTensorTypeList (proxy@TensorTypeProxy :/ rest) =
    tensorType (elementOf proxy) : fromTensorTypeList rest
  where
    -- Placeholder of the proxy's element type: only its type is used, it is
    -- never meant to be evaluated.
    elementOf :: TensorTypeProxy t -> t
    elementOf _ = undefined

-- | The 'DataType' of every element type in @as@, in order.
fromTensorTypes :: forall as . TensorTypes as => Proxy as -> [DataType]
fromTensorTypes _ = fromTensorTypeList (tensorTypes :: TensorTypeList as)
|
|
|
|
-- | Type-level lists whose elements are all 'TensorType's, witnessed at
-- runtime by a list of their proxies.
class TensorTypes (ts :: [*]) where
    tensorTypes :: TensorTypeList ts

instance TensorTypes '[] where
    tensorTypes = Nil

-- | A constraint that the input is a list of 'TensorTypes': a non-empty list
-- qualifies when its head is a 'TensorType' and its tail qualifies.
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) where
    tensorTypes = TensorTypeProxy :/ tensorTypes
|
|
|
|
-- | A simpler version of the 'TensorTypes' class, that doesn't run
-- afoul of @-Wsimplifiable-class-constraints@.
--
-- In more detail: the constraint @OneOf '[Double, Float] a@ leads
-- to the constraint @TensorTypes' '[Double, Float]@, as a safety-check
-- to give better error messages. However, if @TensorTypes'@ were a class,
-- then GHC 8.2.1 would complain with the above warning unless
-- @MonoLocalBinds@ were enabled. So instead, we use a separate type family
-- for this purpose.
-- For more details: https://ghc.haskell.org/trac/ghc/ticket/11948
type family TensorTypes' (ts :: [*]) :: Constraint where
    -- Specialize this type family when `ts` is a long list, to avoid deeply
    -- nested tuples of constraints. Works around a bug in ghc-8.0:
    -- https://ghc.haskell.org/trac/ghc/ticket/12175
    TensorTypes' (t1 ': t2 ': t3 ': t4 ': ts)
        = ( TensorType t1, TensorType t2, TensorType t3, TensorType t4
          , TensorTypes' ts )
    TensorTypes' (t1 ': t2 ': t3 ': ts)
        = (TensorType t1, TensorType t2, TensorType t3, TensorTypes' ts)
    TensorTypes' (t1 ': t2 ': ts)
        = (TensorType t1, TensorType t2, TensorTypes' ts)
    TensorTypes' (t ': ts) = (TensorType t, TensorTypes' ts)
    TensorTypes' '[] = ()
|
|
|
|
-- | A constraint checking that two types are different.
--
-- When both sides are the same type the constraint reduces to the
-- unsatisfiable @TypeError x ~ ExcludedCase@, so the resulting compiler error
-- names the offending type.
type family x /= y :: Constraint where
    x /= x = TypeError x ~ ExcludedCase
    x /= y = ()

-- | Helper types to produce a reasonable type error message when the
-- constraint "a /= a" fails. Neither has any constructors.
--
-- TODO(judahjacobson): Use ghc-8's CustomTypeErrors for this.
data TypeError a
data ExcludedCase
|
|
|
|
-- | An enumeration of all valid 'TensorType's.
--
-- 'OneOf' works by excluding every listed type that is not an allowed choice,
-- so a 'TensorType' missing from this list is never ruled out by 'OneOf'.
-- NOTE(review): 'Word32' and 'Word64' have 'TensorType' instances in this
-- module but are not listed here; confirm that is intended.
type AllTensorTypes =
    -- NOTE: This list should be kept in sync with
    -- TensorFlow.OpGen.dtTypeToHaskell.
    -- TODO: Add support for Complex Float/Double.
    '[ Float, Double
     , Int8, Int16, Int32, Int64
     , Word8, Word16
     , ByteString, Bool
     ]
|
|
|
|
-- | Removes every occurrence of a type from the given list of types,
-- preserving the order of the remaining elements.
type family Delete x xs where
    Delete x '[] = '[]
    Delete x (x ': rest) = Delete x rest
    Delete x (y ': rest) = y ': Delete x rest

-- | Takes the difference of two lists of types: every type in the second list
-- is deleted from the first.
type family xs \\ ys where
    xs \\ '[] = xs
    xs \\ (y ': ys) = Delete y xs \\ ys
|
|
|
|
-- | A constraint that the type @a@ doesn't appear in the type list @ts@,
-- expressed as one '/=' constraint per list element.
--
-- Assumes that @a@ and each of the elements of @ts@ are 'TensorType's.
type family NoneOf ts a :: Constraint where
    -- Specialize this type family when `ts` is a long list, to avoid deeply
    -- nested tuples of constraints. Works around a bug in ghc-8.0:
    -- https://ghc.haskell.org/trac/ghc/ticket/12175
    NoneOf (t1 ': t2 ': t3 ': t4 ': rest) a
        = (a /= t1, a /= t2, a /= t3, a /= t4, NoneOf rest a)
    NoneOf (t1 ': t2 ': t3 ': rest) a
        = (a /= t1, a /= t2, a /= t3, NoneOf rest a)
    NoneOf (t1 ': t2 ': rest) a = (a /= t1, a /= t2, NoneOf rest a)
    NoneOf (t1 ': rest) a = (a /= t1, NoneOf rest a)
    NoneOf '[] a = ()
|