1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-26 21:09:44 +01:00

Support lists of tensors in ops. (#79)

Adds a new type `ListOf` which wraps a heterogeneous list; for example,
`ListOf (Tensor Value) '[Int32, Float]` represents a list of two
elements: a tensor of int32s and a tensor of floats.

Also changes the `Queue2` type (which suppored pairs of tensors) to
`Queue` (which supports arbitrary lists).
This commit is contained in:
Judah Jacobson 2017-03-17 13:53:19 -07:00 committed by Greg Steuck
parent 7cc6a69866
commit 9209dfc4c4
9 changed files with 273 additions and 144 deletions

View file

@ -64,43 +64,9 @@ generatingOpsWrappers = hooks
(prettyLazyText 80 $ docOpList flags x)
blackList =
-- A few data flow ops take a list of heterogeneous
-- parameters which we don't support in general form.
[ "HashTable"
, "MutableDenseHashTable"
, "MutableHashTable"
, "MutableHashTableOfTensors"
, "QueueDequeue"
, "QueueDequeueMany"
, "QueueDequeueUpTo"
, "Stack"
, "TensorArray"
, "TensorArrayV2"
, "QueueEnqueueManyV2"
, "QueueDequeueV2"
, "QueueDequeueUpToV2"
, "QueueEnqueueV2"
, "QueueDequeueManyV2"
, "Stage"
, "Unstage"
-- These should be possible to support by adding a bunch of
-- overloads with a variable number of tuple arguments.
, "Assert"
, "BarrierTakeMany"
, "Print"
, "QueueEnqueue"
, "QueueEnqueueMany"
-- Need list of types support.
, "DecodeCSV"
, "ParseExample"
, "ParseSingleSequenceExample"
, "RestoreV2"
, "Save"
, "SaveV2"
, "SaveSlices"
, "SymbolicGradient"
, "_ArrayToList"
, "_ListToArray"
[ -- Requires the "func" type:
"SymbolicGradient"
-- Easy: support larger result tuples.
, "ParseSingleSequenceExample"
, "Skipgram"
]

View file

@ -147,6 +147,7 @@ imports = stack [
"import Data.ByteString (ByteString)"
, "import Data.Complex (Complex)"
, "import Data.Int (Int8, Int16, Int32, Int64)"
, "import Data.Proxy (Proxy(Proxy))"
, "import Data.Word (Word8, Word16)"
, "import Lens.Family2 ((.~), (&))"
, "import TensorFlow.Build"
@ -210,11 +211,14 @@ whereClause :: [Attr (NonEmpty Name)] -> [Doc]
whereClause [] = []
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
where
defineLengthAttr a = renderHaskellName (attrName a) <+> "="
defineLengthAttr a = renderHaskellAttrName a <+> "="
<+> "fromIntegral (length"
<+> renderHaskellName (NE.head $ attrInfo a)
<> ") :: Int64"
renderHaskellAttrName :: Attr a -> Doc
renderHaskellAttrName = renderHaskellName . attrName
functionBody :: ParsedOp -> Doc
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
</> indent indentation (sep tensorArgs)
@ -229,9 +233,8 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
<- parsedOutputs pOp]
buildOpParts =
"opDef" <+> renderQuotedTFName (parsedOpName pOp) :
-- Renders tensor arguments.
[ "& opAttr" <+> renderQuotedTFName n <+>
".~ tensorType (undefined ::" <+> renderHaskellName n <> ")"
-- Renders type parameter arguments.
[ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a
| a <- inferredTypeAttrs pOp, let n = attrName a
] ++
-- Renders mandatory attributes as function parameters.
@ -244,6 +247,12 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
]
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
inferredTypeExpr a
| typeParamIsList $ attrInfo a
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
<> ")"
| otherwise = "tensorType (undefined ::" <+> renderHaskellAttrName a
<> ")"
-- | Write a comment with the inputs/outputs/attributes in proto format, for
-- debugging.
@ -272,8 +281,8 @@ typeSig pOp = constraints
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
classConstraints = tuple $ concatMap tensorArgConstraint
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
classConstraints = tuple $ map tensorArgConstraint
$ inferredTypeAttrs pOp
signatureFold = folddoc (\x y -> x </> "->" <+> y)
attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
@ -305,17 +314,18 @@ tensorArg p = case parsedArgCase p of
ResourceArg -> "ResourceHandle"
SimpleArg { argType = t, argCaseKind = k } -> tensorType t k
ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k
MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}"
MixedListArg {argTypeAttr = t, argCaseKind = k}
-> "TensorList" <+> kind k <+> renderHaskellName t
where
kind k = case k of
ArgTensorRef -> "Ref"
ArgTensorValue -> "Value"
ArgTensorEither v' -> strictText v'
tensorType t k = let
v = case k of
ArgTensorRef -> "Tensor Ref"
ArgTensorValue -> "Tensor Value"
ArgTensorEither v' -> "Tensor" <+> strictText v'
a = case t of
ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt
ArgTypeAttr n -> renderHaskellName n
in v <+> a
in "Tensor" <+> kind k <+> a
attrComment :: Attr a -> Doc
attrComment a = argComment' (attrName a) (attrDescription a)
@ -347,18 +357,20 @@ resultComment os = stack $ flatten commentSummary : map commentDetails os
]
-- | Constraints for a given type parameter.
-- E.g.: ["TensorType t"] or ["TensorType t", "OneOf [Int64, Float] t"]
tensorArgConstraint :: Attr [DataType] -> [Doc]
tensorArgConstraint a
= ("TensorType" <+> n
: if null typeList
then []
else ["OneOf" <+> "'" <> brackets (commasep typeList) <+> n])
-- E.g.: "TensorType t" or "OneOf [Int64, Float] t"
-- or "TensorTypes ts" or "OneOfs [..] ts".
tensorArgConstraint :: Attr TypeParam -> Doc
tensorArgConstraint a = case attrInfo a of
TypeParam False Nothing -> "TensorType" <+> n
TypeParam False (Just as) -> "OneOf" <+> typeList as <+> n
TypeParam True Nothing -> "TensorTypes" <+> n
TypeParam True (Just as) -> "OneOfs" <+> typeList as <+> n
where
n = renderHaskellName $ attrName a
typeList = map strictText $
Set.toList $ Set.fromList $
map dtTypeToHaskell $ attrInfo a
n = renderHaskellAttrName a
-- Produces a type-level list, e.g.: '[Int32,Int64,Float]
typeList = ("'" <>) . brackets . commasep . map strictText .
Set.toList . Set.fromList .
map dtTypeToHaskell . toList
-- NOTE: The cases of this function should be kept in sync with
-- TensorFlow.Types.AllTensorTypes.

View file

@ -12,6 +12,7 @@ module TensorFlow.OpGen.ParsedOp
, Attr(..)
, AttrType(..)
, AttrBaseType(..)
, TypeParam(..)
, ParsedArg(..)
, ParsedArgCase(..)
, ArgType(..)
@ -62,10 +63,8 @@ data ParsedOp = ParsedOp
, explicitInputAttrs :: [Attr AttrType]
-- ^ Attributes that must be set explicitly when creating the op.
-- Associated with the type of the attribute.
, inferredTypeAttrs :: [Attr [DataType]]
, inferredTypeAttrs :: [Attr TypeParam]
-- ^ Attributes that are type parameters.
-- Associated with the list of allowed types (see: TensorFlow.Types.OneOf).
-- If this list is empty, then any type is acceptable.
, inferredListSizeAttrs :: [Attr (NonEmpty Name)]
-- Attributes which are list sizes (ints) that are inferred automatically
-- from one or more of the input tensors.
@ -104,6 +103,13 @@ data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
| AttrType | AttrShape | AttrTensor
deriving Eq
data TypeParam = TypeParam
{ typeParamIsList :: Bool
, typeParamRestrictions :: Maybe (NonEmpty DataType)
-- ^ The list of allowed types (see: TensorFlow.Types.OneOf).
-- If 'Nothing', then any type is acceptable.
}
-- | An input or output argument (Tensor) for an op.
data ParsedArg = ParsedArg
{ parsedArgName :: Name
@ -120,7 +126,6 @@ data ParsedArgCase
}
| MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind }
-- ^ A heterogeneous list.
-- TODO(judahjacobson): Implement this.
| ResourceArg
argKind :: ParsedArgCase -> Maybe ArgKind
@ -223,11 +228,6 @@ parseOp o = ParsedOp
(o ^. inputArg) tensorKindParams
tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]]
parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg)
-- Type attributes that can be inferred from at least one input or output.
argTypeAttrs = Set.fromList $ mapMaybe parsedArgTypeAttr
$ parsedInputs ++ parsedOutputs
inferredTypeAttrs = filter ((`Set.member` argTypeAttrs) . tfName . attrName)
$ mapMaybeAttrs getInferredTypeAttr $ o ^. attr
-- Integer attributes that can be inferred from the size of at least one
-- input list.
inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs)
@ -235,10 +235,14 @@ parseOp o = ParsedOp
implicitAttrs = Set.fromList $ map tfName $
map attrName inferredTypeAttrs
++ map attrName inferredListSizeAttrs
-- Attributes that can't be inferred and don't have defaults, so must be passed
-- as separate arguments to the op.
inferredTypeAttrs = mapMaybeAttrs (getInferredTypeAttr argTypeParams) $ o ^. attr
argTypeParams = Set.fromList $ map tfName $
mapMaybe (getArgTypeParam . parsedArgCase) $
parsedInputs ++ parsedOutputs
-- Attributes that can't be inferred and don't have defaults, so must be
-- passed as separate arguments to the op.
explicitInputAttrs = sortBy (comparing (tfName . attrName))
$ mapMaybeAttrs (getExplicitInputAttr implicitAttrs)
$ mapMaybeAttrs (getExplicitInputAttr o implicitAttrs)
$ o ^. attr
-- TODO(judahjacobson): Some arguments should be refs.
@ -252,29 +256,30 @@ outputTensorKind a
| a ^. isRef = ArgTensorRef
| otherwise = ArgTensorValue
getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr implicitAttrs a
getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType
getExplicitInputAttr o implicitAttrs a
| TFName (a ^. name) `Set.notMember` implicitAttrs
, a ^. maybe'defaultValue == Nothing
, t <- parseAttrType (a ^. type')
, t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t
, t <- parseAttrType o (a ^. type')
, t `elem` map AttrSingle
[AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape]
++ [AttrList AttrType] = Just t
| otherwise = Nothing
-- | The type attribute used by this input or output (if any).
parsedArgTypeAttr :: ParsedArg -> Maybe TFName
parsedArgTypeAttr p = case parsedArgCase p of
ResourceArg -> Nothing
SimpleArg {argType = t} -> fromArgType t
ListArg {argType = t} -> fromArgType t
MixedListArg {argTypeAttr = n} -> Just $ tfName n
getInferredTypeAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe TypeParam
getInferredTypeAttr argTypeParams a
| TFName (a ^. name) `notElem` argTypeParams = Nothing
| a ^. type' == "type" = Just $ TypeParam False allowed
| a ^. type' == "list(type)" = Just $ TypeParam True allowed
| otherwise = Nothing
where
fromArgType (ArgTypeAttr n) = Just $ tfName n
fromArgType _ = Nothing
allowed = nonEmpty (a ^. allowedValues . list . type')
getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType]
getInferredTypeAttr a
| a ^. type' == "type" = Just $ a ^. allowedValues . list . type'
| otherwise = Nothing
getArgTypeParam :: ParsedArgCase -> Maybe Name
getArgTypeParam SimpleArg { argType = ArgTypeAttr n} = Just n
getArgTypeParam ListArg { argType = ArgTypeAttr n} = Just n
getArgTypeParam MixedListArg { argTypeAttr = n } = Just n
getArgTypeParam _ = Nothing
getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name)
getInferredListSizeAttr inputs a
@ -317,8 +322,8 @@ parseArgCase a tKind
maybeAttr "" = Nothing
maybeAttr t = Just $ makeName t
parseAttrType :: Text -> AttrType
parseAttrType = \case
parseAttrType :: OpDef -> Text -> AttrType
parseAttrType o = \case
"string" -> AttrSingle AttrBytes
"int" -> AttrSingle AttrInt64
"float" -> AttrSingle AttrFloat
@ -334,3 +339,4 @@ parseAttrType = \case
"list(shape)" -> AttrList AttrShape
"list(tensor)" -> AttrList AttrTensor
t -> error $ "parseAttrType: unrecognized type " ++ show t
++ " for op " ++ show (o ^. name)

View file

@ -12,67 +12,65 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Queues in TensorFlow graph. Very limited support for now.
module TensorFlow.Queue (Queue2, makeQueue2, enqueue, dequeue) where
module TensorFlow.Queue (Queue, makeQueue, enqueue, dequeue) where
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Proxy (Proxy(..))
import Lens.Family2 ((.~), (&))
import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef)
import TensorFlow.BuildOp (buildOp)
import TensorFlow.ControlFlow (group)
import TensorFlow.Tensor (Ref, Tensor)
import TensorFlow.Types (TensorType, tensorType)
import TensorFlow.Tensor (Ref, Tensor, TensorList)
import TensorFlow.Types (TensorTypes, fromTensorTypes)
-- | A queue carrying tuples. The underlying structure is more
-- versatile and can be made to support arbitrary tuples.
data Queue2 a b = Queue2 { handle :: Handle }
-- | A queue carrying tuples.
data Queue (as :: [*]) = Queue { handle :: Handle }
type Handle = Tensor Ref ByteString
-- | Adds the given values to the queue.
enqueue :: forall a b v1 v2. (TensorType a, TensorType b)
=> Queue2 a b
-> Tensor v1 a
-> Tensor v2 b
enqueue :: forall as v . TensorTypes as
=> Queue as
-> TensorList v as
-> Build ControlNode
enqueue q =
buildOp (opDef "QueueEnqueue"
& opAttr "Tcomponents" .~ [ tensorType (undefined :: a)
, tensorType (undefined :: b)])
& opAttr "Tcomponents" .~ fromTensorTypes (Proxy :: Proxy as))
(handle q)
-- | Retrieves the values from the queue.
dequeue :: forall a b . (TensorType a, TensorType b)
=> Queue2 a b
-> Build (Tensor Ref a, Tensor Ref b)
-- ^ Dequeued tensors. They are paired in a sense
dequeue :: forall as . TensorTypes as
=> Queue as
-> Build (TensorList Ref as)
-- ^ Dequeued tensors. They are coupled in a sense
-- that values appear together, even if they are
-- not consumed together.
dequeue q =
buildOp (opDef "QueueDequeue"
& opAttr "component_types" .~ [ tensorType (undefined :: a)
, tensorType (undefined :: b)])
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as))
(handle q)
-- | Creates a new queue with the given capacity and shared name.
makeQueue2 :: forall a b . (TensorType a, TensorType b)
makeQueue :: forall as . TensorTypes as
=> Int64 -- ^ The upper bound on the number of elements in
-- this queue. Negative numbers mean no limit.
-> ByteString -- ^ If non-empty, this queue will be shared
-- under the given name across multiple sessions.
-> Build (Queue2 a b)
makeQueue2 capacity sharedName = do
-> Build (Queue as)
makeQueue capacity sharedName = do
q <- buildOp (opDef "FIFOQueue"
& opAttr "component_types" .~ [ tensorType (undefined :: a)
, tensorType (undefined :: b)]
& opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)
& opAttr "shared_name" .~ sharedName
& opAttr "capacity" .~ capacity
)
group q >>= addInitializer
return (Queue2 q)
return (Queue q)
-- TODO(gnezdo): Figure out the closing story for queues.

View file

@ -12,6 +12,7 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
@ -20,7 +21,7 @@ module Main where
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import Google.Test (googleTest)
import TensorFlow.Types (Scalar(..))
import TensorFlow.Types (ListOf(..), Scalar(..), (/:/))
import TensorFlow.Ops (scalar)
import TensorFlow.Queue
import TensorFlow.Session
@ -39,42 +40,50 @@ import qualified Data.ByteString as BS
-- | Test basic queue behaviors.
testBasic :: Test
testBasic = testCase "testBasic" $ runSession $ do
(q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 ""
buildAnd run_ (enqueue q 42 (scalar "Hi"))
q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 ""
buildAnd run_ $ enqueue q $ 42 :/ scalar "Hi" :/ Nil
x <- buildAnd run (dequeue q)
liftIO $ (Scalar 42, Scalar "Hi") @=? x
liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x
buildAnd run_ (enqueue q 56 (scalar "Bar"))
buildAnd run_ $ enqueue q $ 56 :/ scalar "Bar" :/ Nil
y <- buildAnd run (dequeue q)
liftIO $ (Scalar 56, Scalar "Bar") @=? y
-- Note: we use explicit "Scalar" here to specify the type that was
-- fetched. Equivalently we could write
-- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString]
-- or else allow the types to be determined by future use of the fetched
-- value.
let expected = Scalar 56 /:/ Scalar "Bar" /:/ Nil
liftIO $ expected @=? y
-- | Test queue pumping.
testPump :: Test
testPump = testCase "testPump" $ runSession $ do
(deq, pump) <- build $ do
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue"
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 "ThePumpQueue"
(,) <$> dequeue q
<*> enqueue q 31 (scalar "Baz")
<*> enqueue q (31 :/ scalar "Baz" :/ Nil)
-- This is a realistic use. The pump inputs are pre-bound to some
-- nodes that produce values when pumped (e.g. read from a
-- file).
run_ (pump, pump)
(x, y) <- run (deq, deq)
liftIO $ (Scalar 31, Scalar "Baz") @=? x
liftIO $ (Scalar 31, Scalar "Baz") @=? y
let expected = Scalar 31 /:/ Scalar "Baz" /:/ Nil
liftIO $ expected @=? x
liftIO $ expected @=? y
testAsync :: Test
testAsync = testCase "testAsync" $ runSession $ do
(deq, pump) <- build $ do
q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 ""
q :: Queue [Int64, BS.ByteString] <- makeQueue 2 ""
(,) <$> dequeue q
<*> enqueue q 10 (scalar "Async")
<*> enqueue q (10 :/ scalar "Async" :/ Nil)
-- Pumps the queue until canceled by runSession exiting.
asyncProdNodes pump
-- Picks up a couple values and verifies they are as expected.
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?)
let expected = Scalar 10 /:/ Scalar "Async" /:/ Nil
run deq >>= liftIO . (expected @=?)
run deq >>= liftIO . (expected @=?)
main :: IO ()
main = googleTest [ testBasic

View file

@ -13,6 +13,8 @@
-- limitations under the License.
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module TensorFlow.BuildOp
@ -33,6 +35,7 @@ import Lens.Family2 ((&), (<>~), (^.))
import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types
data ResultState = ResultState !OutputIx [Int64] deriving Show
@ -98,6 +101,22 @@ instance OpResult (Tensor Ref a) where
instance OpResult ControlNode where
toResult = ControlNode <$> ask
tensorListResult :: forall as v . TensorTypes as => TensorKind v -> Result (TensorList v as)
tensorListResult v = loop (tensorTypes :: TensorTypeList as)
where
loop :: TensorTypeList bs -> Result (TensorList v bs)
loop Nil = return Nil
loop (TensorTypeProxy :/ ls) = do
t <- tensorResult v
ts <- loop ls
return (t :/ ts)
instance TensorTypes as => OpResult (TensorList Value as) where
toResult = tensorListResult ValueKind
instance TensorTypes as => OpResult (TensorList Ref as) where
toResult = tensorListResult RefKind
instance OpResult a => OpResult [a] where
toResult = do
ResultState i ns <- get
@ -159,6 +178,12 @@ instance BuildOp (Tensor Value a) where
instance BuildOp (Tensor Ref a) where
buildOp' = pureResult
instance TensorTypes as => BuildOp (TensorList Value as) where
buildOp' = pureResult
instance TensorTypes as => BuildOp (TensorList Ref as) where
buildOp' = pureResult
instance BuildOp [Tensor Value a] where
buildOp' = pureResult
@ -199,6 +224,10 @@ instance BuildOp f => BuildOp ([Tensor v a] -> f) where
buildOp' rf o accum ts
= buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum)
instance BuildOp f => BuildOp (TensorList v as -> f) where
buildOp' rf o accum ts
= buildOp' rf o (reverse (tensorListOutputs ts) ++ accum)
-- | Returns true if all the integers in each tuple are identical.
-- Throws an error with a descriptive message if not.
eqLengthGuard :: [(String, [(String, Int)])] -> Bool

View file

@ -12,15 +12,18 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module TensorFlow.Nodes where
import Control.Applicative (liftA2, liftA3)
import Data.Functor.Identity (Identity)
import Data.Map.Strict (Map)
import Data.Monoid ((<>))
import Data.Set (Set)
@ -96,6 +99,19 @@ instance Nodes ControlNode where
instance a ~ () => Fetchable ControlNode a where
getFetch _ = return $ pure ()
instance Nodes (ListOf f '[]) where
getNodes _ = return Set.empty
instance (Nodes (f a), Nodes (ListOf f as)) => Nodes (ListOf f (a ': as)) where
getNodes (x :/ xs) = liftA2 Set.union (getNodes x) (getNodes xs)
instance l ~ List '[] => Fetchable (ListOf f '[]) l where
getFetch _ = return $ pure Nil
instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity)
=> Fetchable (ListOf f (t ': ts)) (ListOf i (a ': as)) where
getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs
instance Nodes (Tensor v a) where
getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp)

View file

@ -12,20 +12,30 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module TensorFlow.Tensor where
import Data.String (IsString(..))
import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal')
import Lens.Family2 (Lens', Traversal', (^.))
import Lens.Family2.Unchecked (lens)
import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr)
import TensorFlow.Types (TensorData(..), Attribute)
import TensorFlow.Types
( TensorData(..)
, Attribute
, ListOf(..)
)
import qualified TensorFlow.Internal.FFI as FFI
-- | A named output of a TensorFlow operation.
@ -83,3 +93,9 @@ feed (Tensor _ o) (TensorData td) = Feed o td
-- TODO(judahjacobson): add more safety checks here.
tensorFromName :: TensorKind v -> Text.Text -> Tensor v a
tensorFromName v = Tensor v . fromString . Text.unpack
type TensorList v = ListOf (Tensor v)
tensorListOutputs :: TensorList v as -> [Output]
tensorListOutputs Nil = []
tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts

View file

@ -13,9 +13,11 @@
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
@ -36,23 +38,35 @@ module TensorFlow.Types
, Shape(..)
, protoShape
, Attribute(..)
, DataType(..)
-- * Lists
, ListOf(..)
, List
, (/:/)
, TensorTypeProxy(..)
, TensorTypes(..)
, TensorTypeList
, fromTensorTypeList
, fromTensorTypes
-- * Type constraints
, OneOf
, type (/=)
, OneOfs
-- ** Implementation of constraints
, TypeError
, ExcludedCase
, TensorTypes
, NoneOf
, type (\\)
, Delete
, AllTensorTypes
) where
import Data.Functor.Identity (Identity(..))
import Data.Complex (Complex)
import Data.Default (def)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Monoid ((<>))
import Data.Proxy (Proxy(..))
import Data.String (IsString)
import Data.Word (Word8, Word16, Word64)
import Foreign.Storable (Storable)
@ -376,6 +390,44 @@ instance Attribute [DataType] where
instance Attribute [Int64] where
attrLens = list . i
-- | A heterogeneous list type.
data ListOf f as where
Nil :: ListOf f '[]
(:/) :: f a -> ListOf f as -> ListOf f (a ': as)
infixr 5 :/
type family All f as :: Constraint where
All f '[] = ()
All f (a ': as) = (f a, All f as)
type family Map f as where
Map f '[] = '[]
Map f (a ': as) = f a ': Map f as
instance All Eq (Map f as) => Eq (ListOf f as) where
Nil == Nil = True
(x :/ xs) == (y :/ ys) = x == y && xs == ys
-- Newer versions of GHC use the GADT to tell that the previous cases are
-- exhaustive.
#if _GLASGOW_HASKELL__ < 800
_ == _ = False
#endif
instance All Show (Map f as) => Show (ListOf f as) where
showsPrec _ Nil = showString "Nil"
showsPrec d (x :/ xs) = showParen (d > 10)
$ showsPrec 6 x . showString " :/ "
. showsPrec 6 xs
type List = ListOf Identity
-- | Equivalent of ':/' for lists.
(/:/) :: a -> List as -> List (a ': as)
(/:/) = (:/) . Identity
infixr 5 /:/
-- | A 'Constraint' specifying the possible choices of a 'TensorType'.
--
-- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the
@ -393,13 +445,38 @@ instance Attribute [Int64] where
--
-- using an enumeration of all the possible 'TensorType's.
type OneOf ts a
-- Assert `TensorTypes ts` to make error messages a little better.
= (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a)
-- | A check that the input is a list of 'TensorType's.
-- Helps improve error messages when using 'OneOf'.
type OneOfs ts as = (TensorTypes as, TensorTypes ts,
NoneOfs (AllTensorTypes \\ ts) as)
type family NoneOfs ts as :: Constraint where
NoneOfs ts '[] = ()
NoneOfs ts (a ': as) = (NoneOf ts a, NoneOfs ts as)
data TensorTypeProxy a where
TensorTypeProxy :: TensorType a => TensorTypeProxy a
type TensorTypeList = ListOf TensorTypeProxy
fromTensorTypeList :: TensorTypeList ts -> [DataType]
fromTensorTypeList Nil = []
fromTensorTypeList ((TensorTypeProxy :: TensorTypeProxy t) :/ ts)
= tensorType (undefined :: t) : fromTensorTypeList ts
fromTensorTypes :: forall as . TensorTypes as => Proxy as -> [DataType]
fromTensorTypes _ = fromTensorTypeList (tensorTypes :: TensorTypeList as)
class TensorTypes (ts :: [*]) where
instance TensorTypes '[]
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts)
tensorTypes :: TensorTypeList ts
instance TensorTypes '[] where
tensorTypes = Nil
-- | A constraint that the input is a list of 'TensorTypes'.
instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) where
tensorTypes = TensorTypeProxy :/ tensorTypes
-- | A constraint checking that two types are different.
type family a /= b :: Constraint where