diff --git a/tensorflow-core-ops/Setup.hs b/tensorflow-core-ops/Setup.hs index 349ae11..79eba2a 100644 --- a/tensorflow-core-ops/Setup.hs +++ b/tensorflow-core-ops/Setup.hs @@ -64,43 +64,9 @@ generatingOpsWrappers = hooks (prettyLazyText 80 $ docOpList flags x) blackList = - -- A few data flow ops take a list of heterogeneous - -- parameters which we don't support in general form. - [ "HashTable" - , "MutableDenseHashTable" - , "MutableHashTable" - , "MutableHashTableOfTensors" - , "QueueDequeue" - , "QueueDequeueMany" - , "QueueDequeueUpTo" - , "Stack" - , "TensorArray" - , "TensorArrayV2" - , "QueueEnqueueManyV2" - , "QueueDequeueV2" - , "QueueDequeueUpToV2" - , "QueueEnqueueV2" - , "QueueDequeueManyV2" - , "Stage" - , "Unstage" - -- These should be possible to support by adding a bunch of - -- overloads with a variable number of tuple arguments. - , "Assert" - , "BarrierTakeMany" - , "Print" - , "QueueEnqueue" - , "QueueEnqueueMany" - -- Need list of types support. - , "DecodeCSV" - , "ParseExample" - , "ParseSingleSequenceExample" - , "RestoreV2" - , "Save" - , "SaveV2" - , "SaveSlices" - , "SymbolicGradient" - , "_ArrayToList" - , "_ListToArray" + [ -- Requires the "func" type: + "SymbolicGradient" -- Easy: support larger result tuples. + , "ParseSingleSequenceExample" , "Skipgram" ] diff --git a/tensorflow-opgen/src/TensorFlow/OpGen.hs b/tensorflow-opgen/src/TensorFlow/OpGen.hs index 8e2a6f7..c8fa43b 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen.hs @@ -147,6 +147,7 @@ imports = stack [ "import Data.ByteString (ByteString)" , "import Data.Complex (Complex)" , "import Data.Int (Int8, Int16, Int32, Int64)" + , "import Data.Proxy (Proxy(Proxy))" , "import Data.Word (Word8, Word16)" , "import Lens.Family2 ((.~), (&))" , "import TensorFlow.Build" @@ -210,11 +211,14 @@ whereClause :: [Attr (NonEmpty Name)] -> [Doc] whereClause [] = [] whereClause as = [indent 2 $ "where" indent 2 (stack $ map defineLengthAttr as)] where - defineLengthAttr a = renderHaskellName (attrName a) <+> "=" + defineLengthAttr a = renderHaskellAttrName a <+> "=" <+> "fromIntegral (length" <+> renderHaskellName (NE.head $ attrInfo a) <> ") :: Int64" +renderHaskellAttrName :: Attr a -> Doc +renderHaskellAttrName = renderHaskellName . attrName + functionBody :: ParsedOp -> Doc functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts)) indent indentation (sep tensorArgs) @@ -229,9 +233,8 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts)) <- parsedOutputs pOp] buildOpParts = "opDef" <+> renderQuotedTFName (parsedOpName pOp) : - -- Renders tensor arguments. - [ "& opAttr" <+> renderQuotedTFName n <+> - ".~ tensorType (undefined ::" <+> renderHaskellName n <> ")" + -- Renders type parameter arguments. + [ "& opAttr" <+> renderQuotedTFName n <+> ".~" <+> inferredTypeExpr a | a <- inferredTypeAttrs pOp, let n = attrName a ] ++ -- Renders mandatory attributes as function parameters. @@ -244,6 +247,12 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts)) ] tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp + inferredTypeExpr a + | typeParamIsList $ attrInfo a + = "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a + <> ")" + | otherwise = "tensorType (undefined ::" <+> renderHaskellAttrName a + <> ")" -- | Write a comment with the inputs/outputs/attributes in proto format, for -- debugging. @@ -272,8 +281,8 @@ typeSig pOp = constraints | otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>" typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp, Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]] - ++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp] - classConstraints = tuple $ concatMap tensorArgConstraint + ++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp] + classConstraints = tuple $ map tensorArgConstraint $ inferredTypeAttrs pOp signatureFold = folddoc (\x y -> x "->" <+> y) attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a) @@ -305,17 +314,18 @@ tensorArg p = case parsedArgCase p of ResourceArg -> "ResourceHandle" SimpleArg { argType = t, argCaseKind = k } -> tensorType t k ListArg { argType = t, argCaseKind = k } -> brackets $ tensorType t k - MixedListArg {} -> "{{{tensorArg: can't handle heterogeneous lists}}}" + MixedListArg {argTypeAttr = t, argCaseKind = k} + -> "TensorList" <+> kind k <+> renderHaskellName t where + kind k = case k of + ArgTensorRef -> "Ref" + ArgTensorValue -> "Value" + ArgTensorEither v' -> strictText v' tensorType t k = let - v = case k of - ArgTensorRef -> "Tensor Ref" - ArgTensorValue -> "Tensor Value" - ArgTensorEither v' -> "Tensor" <+> strictText v' a = case t of ArgTypeFixed dt -> strictText $ dtTypeToHaskell dt ArgTypeAttr n -> renderHaskellName n - in v <+> a + in "Tensor" <+> kind k <+> a attrComment :: Attr a -> Doc attrComment a = argComment' (attrName a) (attrDescription a) @@ -347,18 +357,20 @@ resultComment os = stack $ flatten commentSummary : map commentDetails os ] -- | Constraints for a given type parameter. --- E.g.: ["TensorType t"] or ["TensorType t", "OneOf [Int64, Float] t"] -tensorArgConstraint :: Attr [DataType] -> [Doc] -tensorArgConstraint a - = ("TensorType" <+> n - : if null typeList - then [] - else ["OneOf" <+> "'" <> brackets (commasep typeList) <+> n]) +-- E.g.: "TensorType t" or "OneOf [Int64, Float] t" +-- or "TensorTypes ts" or "OneOfs [..] ts". +tensorArgConstraint :: Attr TypeParam -> Doc +tensorArgConstraint a = case attrInfo a of + TypeParam False Nothing -> "TensorType" <+> n + TypeParam False (Just as) -> "OneOf" <+> typeList as <+> n + TypeParam True Nothing -> "TensorTypes" <+> n + TypeParam True (Just as) -> "OneOfs" <+> typeList as <+> n where - n = renderHaskellName $ attrName a - typeList = map strictText $ - Set.toList $ Set.fromList $ - map dtTypeToHaskell $ attrInfo a + n = renderHaskellAttrName a + -- Produces a type-level list, e.g.: '[Int32,Int64,Float] + typeList = ("'" <>) . brackets . commasep . map strictText . + Set.toList . Set.fromList . + map dtTypeToHaskell . toList -- NOTE: The cases of this function should be kept in sync with -- TensorFlow.Types.AllTensorTypes. diff --git a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs index 49d5c34..c7a5b69 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen/ParsedOp.hs @@ -12,6 +12,7 @@ module TensorFlow.OpGen.ParsedOp , Attr(..) , AttrType(..) , AttrBaseType(..) + , TypeParam(..) , ParsedArg(..) , ParsedArgCase(..) , ArgType(..) @@ -62,10 +63,8 @@ data ParsedOp = ParsedOp , explicitInputAttrs :: [Attr AttrType] -- ^ Attributes that must be set explicitly when creating the op. -- Associated with the type of the attribute. - , inferredTypeAttrs :: [Attr [DataType]] + , inferredTypeAttrs :: [Attr TypeParam] -- ^ Attributes that are type parameters. - -- Associated with the list of allowed types (see: TensorFlow.Types.OneOf). - -- If this list is empty, then any type is acceptable. , inferredListSizeAttrs :: [Attr (NonEmpty Name)] -- Attributes which are list sizes (ints) that are inferred automatically -- from one or more of the input tensors. @@ -104,6 +103,13 @@ data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool | AttrType | AttrShape | AttrTensor deriving Eq +data TypeParam = TypeParam + { typeParamIsList :: Bool + , typeParamRestrictions :: Maybe (NonEmpty DataType) + -- ^ The list of allowed types (see: TensorFlow.Types.OneOf). + -- If 'Nothing', then any type is acceptable. + } + -- | An input or output argument (Tensor) for an op. data ParsedArg = ParsedArg { parsedArgName :: Name @@ -120,7 +126,6 @@ data ParsedArgCase } | MixedListArg { argTypeAttr :: Name, argCaseKind :: ArgKind } -- ^ A heterogeneous list. - -- TODO(judahjacobson): Implement this. | ResourceArg argKind :: ParsedArgCase -> Maybe ArgKind @@ -223,11 +228,6 @@ parseOp o = ParsedOp (o ^. inputArg) tensorKindParams tensorKindParams = ["v" <> Text.pack (show x) | x <- [1::Integer ..]] parsedOutputs = map (\a -> parseArg a (outputTensorKind a)) (o ^. outputArg) - -- Type attributes that can be inferred from at least one input or output. - argTypeAttrs = Set.fromList $ mapMaybe parsedArgTypeAttr - $ parsedInputs ++ parsedOutputs - inferredTypeAttrs = filter ((`Set.member` argTypeAttrs) . tfName . attrName) - $ mapMaybeAttrs getInferredTypeAttr $ o ^. attr -- Integer attributes that can be inferred from the size of at least one -- input list. inferredListSizeAttrs = mapMaybeAttrs (getInferredListSizeAttr parsedInputs) @@ -235,10 +235,14 @@ parseOp o = ParsedOp implicitAttrs = Set.fromList $ map tfName $ map attrName inferredTypeAttrs ++ map attrName inferredListSizeAttrs - -- Attributes that can't be inferred and don't have defaults, so must be passed - -- as separate arguments to the op. + inferredTypeAttrs = mapMaybeAttrs (getInferredTypeAttr argTypeParams) $ o ^. attr + argTypeParams = Set.fromList $ map tfName $ + mapMaybe (getArgTypeParam . parsedArgCase) $ + parsedInputs ++ parsedOutputs + -- Attributes that can't be inferred and don't have defaults, so must be + -- passed as separate arguments to the op. explicitInputAttrs = sortBy (comparing (tfName . attrName)) - $ mapMaybeAttrs (getExplicitInputAttr implicitAttrs) + $ mapMaybeAttrs (getExplicitInputAttr o implicitAttrs) $ o ^. attr -- TODO(judahjacobson): Some arguments should be refs. @@ -252,29 +256,30 @@ outputTensorKind a | a ^. isRef = ArgTensorRef | otherwise = ArgTensorValue -getExplicitInputAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType -getExplicitInputAttr implicitAttrs a +getExplicitInputAttr :: OpDef -> Set.Set TFName -> OpDef'AttrDef -> Maybe AttrType +getExplicitInputAttr o implicitAttrs a | TFName (a ^. name) `Set.notMember` implicitAttrs , a ^. maybe'defaultValue == Nothing - , t <- parseAttrType (a ^. type') - , t `elem` map AttrSingle [AttrBool, AttrInt64, AttrFloat, AttrShape] = Just t + , t <- parseAttrType o (a ^. type') + , t `elem` map AttrSingle + [AttrBool, AttrInt64, AttrFloat, AttrType, AttrShape] + ++ [AttrList AttrType] = Just t | otherwise = Nothing --- | The type attribute used by this input or output (if any). -parsedArgTypeAttr :: ParsedArg -> Maybe TFName -parsedArgTypeAttr p = case parsedArgCase p of - ResourceArg -> Nothing - SimpleArg {argType = t} -> fromArgType t - ListArg {argType = t} -> fromArgType t - MixedListArg {argTypeAttr = n} -> Just $ tfName n +getInferredTypeAttr :: Set.Set TFName -> OpDef'AttrDef -> Maybe TypeParam +getInferredTypeAttr argTypeParams a + | TFName (a ^. name) `notElem` argTypeParams = Nothing + | a ^. type' == "type" = Just $ TypeParam False allowed + | a ^. type' == "list(type)" = Just $ TypeParam True allowed + | otherwise = Nothing where - fromArgType (ArgTypeAttr n) = Just $ tfName n - fromArgType _ = Nothing + allowed = nonEmpty (a ^. allowedValues . list . type') -getInferredTypeAttr :: OpDef'AttrDef -> Maybe [DataType] -getInferredTypeAttr a - | a ^. type' == "type" = Just $ a ^. allowedValues . list . type' - | otherwise = Nothing +getArgTypeParam :: ParsedArgCase -> Maybe Name +getArgTypeParam SimpleArg { argType = ArgTypeAttr n} = Just n +getArgTypeParam ListArg { argType = ArgTypeAttr n} = Just n +getArgTypeParam MixedListArg { argTypeAttr = n } = Just n +getArgTypeParam _ = Nothing getInferredListSizeAttr :: [ParsedArg] -> OpDef'AttrDef -> Maybe (NonEmpty Name) getInferredListSizeAttr inputs a @@ -285,7 +290,7 @@ getInferredListSizeAttr inputs a } <- inputs , TFName (a ^. name) == tfName n] | otherwise = Nothing - + -- | Like mapMaybe, but associates the attribute name/description with the given info. mapMaybeAttrs :: (OpDef'AttrDef -> Maybe a) -> [OpDef'AttrDef] -> [Attr a] mapMaybeAttrs f = mapMaybe $ \a -> do @@ -295,7 +300,7 @@ mapMaybeAttrs f = mapMaybe $ \a -> do , attrDescription = a ^. description , attrInfo = x } - + parseArg :: OpDef'ArgDef -> ArgKind -> ParsedArg parseArg a tKind = ParsedArg { parsedArgName = makeName (a ^. name) @@ -317,15 +322,15 @@ parseArgCase a tKind maybeAttr "" = Nothing maybeAttr t = Just $ makeName t -parseAttrType :: Text -> AttrType -parseAttrType = \case +parseAttrType :: OpDef -> Text -> AttrType +parseAttrType o = \case "string" -> AttrSingle AttrBytes - "int" -> AttrSingle AttrInt64 - "float" -> AttrSingle AttrFloat - "bool" -> AttrSingle AttrBool - "type" -> AttrSingle AttrType - "shape" -> AttrSingle AttrShape - "tensor" -> AttrSingle AttrTensor + "int" -> AttrSingle AttrInt64 + "float" -> AttrSingle AttrFloat + "bool" -> AttrSingle AttrBool + "type" -> AttrSingle AttrType + "shape" -> AttrSingle AttrShape + "tensor" -> AttrSingle AttrTensor "list(string)" -> AttrList AttrBytes "list(int)" -> AttrList AttrInt64 "list(float)" -> AttrList AttrFloat @@ -334,3 +339,4 @@ parseAttrType = \case "list(shape)" -> AttrList AttrShape "list(tensor)" -> AttrList AttrTensor t -> error $ "parseAttrType: unrecognized type " ++ show t + ++ " for op " ++ show (o ^. name) diff --git a/tensorflow-queue/src/TensorFlow/Queue.hs b/tensorflow-queue/src/TensorFlow/Queue.hs index 0d0ddca..f906f4a 100644 --- a/tensorflow-queue/src/TensorFlow/Queue.hs +++ b/tensorflow-queue/src/TensorFlow/Queue.hs @@ -12,67 +12,65 @@ -- See the License for the specific language governing permissions and -- limitations under the License. +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE KindSignatures #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} -- | Queues in TensorFlow graph. Very limited support for now. -module TensorFlow.Queue (Queue2, makeQueue2, enqueue, dequeue) where +module TensorFlow.Queue (Queue, makeQueue, enqueue, dequeue) where import Data.ByteString (ByteString) import Data.Int (Int64) +import Data.Proxy (Proxy(..)) import Lens.Family2 ((.~), (&)) import TensorFlow.Build (ControlNode, Build, addInitializer, opAttr, opDef) import TensorFlow.BuildOp (buildOp) import TensorFlow.ControlFlow (group) -import TensorFlow.Tensor (Ref, Tensor) -import TensorFlow.Types (TensorType, tensorType) +import TensorFlow.Tensor (Ref, Tensor, TensorList) +import TensorFlow.Types (TensorTypes, fromTensorTypes) --- | A queue carrying tuples. The underlying structure is more --- versatile and can be made to support arbitrary tuples. -data Queue2 a b = Queue2 { handle :: Handle } +-- | A queue carrying tuples. +data Queue (as :: [*]) = Queue { handle :: Handle } type Handle = Tensor Ref ByteString -- | Adds the given values to the queue. -enqueue :: forall a b v1 v2. (TensorType a, TensorType b) - => Queue2 a b - -> Tensor v1 a - -> Tensor v2 b +enqueue :: forall as v . TensorTypes as + => Queue as + -> TensorList v as -> Build ControlNode enqueue q = buildOp (opDef "QueueEnqueue" - & opAttr "Tcomponents" .~ [ tensorType (undefined :: a) - , tensorType (undefined :: b)]) + & opAttr "Tcomponents" .~ fromTensorTypes (Proxy :: Proxy as)) (handle q) -- | Retrieves the values from the queue. -dequeue :: forall a b . (TensorType a, TensorType b) - => Queue2 a b - -> Build (Tensor Ref a, Tensor Ref b) - -- ^ Dequeued tensors. They are paired in a sense +dequeue :: forall as . TensorTypes as + => Queue as + -> Build (TensorList Ref as) + -- ^ Dequeued tensors. They are coupled in a sense -- that values appear together, even if they are -- not consumed together. dequeue q = buildOp (opDef "QueueDequeue" - & opAttr "component_types" .~ [ tensorType (undefined :: a) - , tensorType (undefined :: b)]) + & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as)) (handle q) -- | Creates a new queue with the given capacity and shared name. -makeQueue2 :: forall a b . (TensorType a, TensorType b) +makeQueue :: forall as . TensorTypes as => Int64 -- ^ The upper bound on the number of elements in -- this queue. Negative numbers mean no limit. -> ByteString -- ^ If non-empty, this queue will be shared -- under the given name across multiple sessions. - -> Build (Queue2 a b) -makeQueue2 capacity sharedName = do + -> Build (Queue as) +makeQueue capacity sharedName = do q <- buildOp (opDef "FIFOQueue" - & opAttr "component_types" .~ [ tensorType (undefined :: a) - , tensorType (undefined :: b)] + & opAttr "component_types" .~ fromTensorTypes (Proxy :: Proxy as) & opAttr "shared_name" .~ sharedName & opAttr "capacity" .~ capacity ) group q >>= addInitializer - return (Queue2 q) + return (Queue q) -- TODO(gnezdo): Figure out the closing story for queues. diff --git a/tensorflow-queue/tests/QueueTest.hs b/tensorflow-queue/tests/QueueTest.hs index f3b38eb..5aa0e54 100644 --- a/tensorflow-queue/tests/QueueTest.hs +++ b/tensorflow-queue/tests/QueueTest.hs @@ -12,6 +12,7 @@ -- See the License for the specific language governing permissions and -- limitations under the License. +{-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -20,7 +21,7 @@ module Main where import Control.Monad.IO.Class (liftIO) import Data.Int (Int64) import Google.Test (googleTest) -import TensorFlow.Types (Scalar(..)) +import TensorFlow.Types (ListOf(..), Scalar(..), (/:/)) import TensorFlow.Ops (scalar) import TensorFlow.Queue import TensorFlow.Session @@ -39,42 +40,50 @@ import qualified Data.ByteString as BS -- | Test basic queue behaviors. testBasic :: Test testBasic = testCase "testBasic" $ runSession $ do - (q :: Queue2 Int64 BS.ByteString) <- build $ makeQueue2 1 "" - buildAnd run_ (enqueue q 42 (scalar "Hi")) + q :: Queue [Int64, BS.ByteString] <- build $ makeQueue 1 "" + buildAnd run_ $ enqueue q $ 42 :/ scalar "Hi" :/ Nil x <- buildAnd run (dequeue q) - liftIO $ (Scalar 42, Scalar "Hi") @=? x + liftIO $ (Scalar 42 /:/ Scalar "Hi" /:/ Nil) @=? x - buildAnd run_ (enqueue q 56 (scalar "Bar")) + buildAnd run_ $ enqueue q $ 56 :/ scalar "Bar" :/ Nil y <- buildAnd run (dequeue q) - liftIO $ (Scalar 56, Scalar "Bar") @=? y + -- Note: we use explicit "Scalar" here to specify the type that was + -- fetched. Equivalently we could write + -- 56 /:/ "Bar" /:/ Nil :: List [Scalar Int64, Scalar BS.ByteString] + -- or else allow the types to be determined by future use of the fetched + -- value. + let expected = Scalar 56 /:/ Scalar "Bar" /:/ Nil + liftIO $ expected @=? y -- | Test queue pumping. testPump :: Test testPump = testCase "testPump" $ runSession $ do (deq, pump) <- build $ do - q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "ThePumpQueue" + q :: Queue [Int64, BS.ByteString] <- makeQueue 2 "ThePumpQueue" (,) <$> dequeue q - <*> enqueue q 31 (scalar "Baz") + <*> enqueue q (31 :/ scalar "Baz" :/ Nil) -- This is a realistic use. The pump inputs are pre-bound to some -- nodes that produce values when pumped (e.g. read from a -- file). run_ (pump, pump) (x, y) <- run (deq, deq) - liftIO $ (Scalar 31, Scalar "Baz") @=? x - liftIO $ (Scalar 31, Scalar "Baz") @=? y + let expected = Scalar 31 /:/ Scalar "Baz" /:/ Nil + liftIO $ expected @=? x + liftIO $ expected @=? y testAsync :: Test testAsync = testCase "testAsync" $ runSession $ do (deq, pump) <- build $ do - q :: Queue2 Int64 BS.ByteString <- makeQueue2 2 "" + q :: Queue [Int64, BS.ByteString] <- makeQueue 2 "" (,) <$> dequeue q - <*> enqueue q 10 (scalar "Async") + <*> enqueue q (10 :/ scalar "Async" :/ Nil) -- Pumps the queue until canceled by runSession exiting. asyncProdNodes pump -- Picks up a couple values and verifies they are as expected. - run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?) - run deq >>= liftIO . ((Scalar 10, Scalar "Async") @=?) + let expected = Scalar 10 /:/ Scalar "Async" /:/ Nil + run deq >>= liftIO . (expected @=?) + run deq >>= liftIO . (expected @=?) main :: IO () main = googleTest [ testBasic diff --git a/tensorflow/src/TensorFlow/BuildOp.hs b/tensorflow/src/TensorFlow/BuildOp.hs index 6b2df3e..3411e24 100644 --- a/tensorflow/src/TensorFlow/BuildOp.hs +++ b/tensorflow/src/TensorFlow/BuildOp.hs @@ -13,6 +13,8 @@ -- limitations under the License. {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} module TensorFlow.BuildOp @@ -33,6 +35,7 @@ import Lens.Family2 ((&), (<>~), (^.)) import TensorFlow.Build import TensorFlow.Output import TensorFlow.Tensor +import TensorFlow.Types data ResultState = ResultState !OutputIx [Int64] deriving Show @@ -98,6 +101,22 @@ instance OpResult (Tensor Ref a) where instance OpResult ControlNode where toResult = ControlNode <$> ask +tensorListResult :: forall as v . TensorTypes as => TensorKind v -> Result (TensorList v as) +tensorListResult v = loop (tensorTypes :: TensorTypeList as) + where + loop :: TensorTypeList bs -> Result (TensorList v bs) + loop Nil = return Nil + loop (TensorTypeProxy :/ ls) = do + t <- tensorResult v + ts <- loop ls + return (t :/ ts) + +instance TensorTypes as => OpResult (TensorList Value as) where + toResult = tensorListResult ValueKind + +instance TensorTypes as => OpResult (TensorList Ref as) where + toResult = tensorListResult RefKind + instance OpResult a => OpResult [a] where toResult = do ResultState i ns <- get @@ -159,6 +178,12 @@ instance BuildOp (Tensor Value a) where instance BuildOp (Tensor Ref a) where buildOp' = pureResult +instance TensorTypes as => BuildOp (TensorList Value as) where + buildOp' = pureResult + +instance TensorTypes as => BuildOp (TensorList Ref as) where + buildOp' = pureResult + instance BuildOp [Tensor Value a] where buildOp' = pureResult @@ -199,6 +224,10 @@ instance BuildOp f => BuildOp ([Tensor v a] -> f) where buildOp' rf o accum ts = buildOp' rf o (reverse (fmap (^. tensorOutput) ts) ++ accum) +instance BuildOp f => BuildOp (TensorList v as -> f) where + buildOp' rf o accum ts + = buildOp' rf o (reverse (tensorListOutputs ts) ++ accum) + -- | Returns true if all the integers in each tuple are identical. -- Throws an error with a descriptive message if not. eqLengthGuard :: [(String, [(String, Int)])] -> Bool diff --git a/tensorflow/src/TensorFlow/Nodes.hs b/tensorflow/src/TensorFlow/Nodes.hs index 5e8c62d..a7ce925 100644 --- a/tensorflow/src/TensorFlow/Nodes.hs +++ b/tensorflow/src/TensorFlow/Nodes.hs @@ -12,15 +12,18 @@ -- See the License for the specific language governing permissions and -- limitations under the License. +{-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} module TensorFlow.Nodes where import Control.Applicative (liftA2, liftA3) +import Data.Functor.Identity (Identity) import Data.Map.Strict (Map) import Data.Monoid ((<>)) import Data.Set (Set) @@ -96,6 +99,19 @@ instance Nodes ControlNode where instance a ~ () => Fetchable ControlNode a where getFetch _ = return $ pure () +instance Nodes (ListOf f '[]) where + getNodes _ = return Set.empty + +instance (Nodes (f a), Nodes (ListOf f as)) => Nodes (ListOf f (a ': as)) where + getNodes (x :/ xs) = liftA2 Set.union (getNodes x) (getNodes xs) + +instance l ~ List '[] => Fetchable (ListOf f '[]) l where + getFetch _ = return $ pure Nil + +instance (Fetchable (f t) a, Fetchable (ListOf f ts) (List as), i ~ Identity) + => Fetchable (ListOf f (t ': ts)) (ListOf i (a ': as)) where + getFetch (x :/ xs) = liftA2 (\y ys -> y /:/ ys) <$> getFetch x <*> getFetch xs + instance Nodes (Tensor v a) where getNodes t = Set.singleton <$> getOrAddOp (t ^. tensorOutput . outputOp) diff --git a/tensorflow/src/TensorFlow/Tensor.hs b/tensorflow/src/TensorFlow/Tensor.hs index da03184..7d6ca4f 100644 --- a/tensorflow/src/TensorFlow/Tensor.hs +++ b/tensorflow/src/TensorFlow/Tensor.hs @@ -12,20 +12,30 @@ -- See the License for the specific language governing permissions and -- limitations under the License. +{-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE Rank2Types #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} module TensorFlow.Tensor where import Data.String (IsString(..)) import qualified Data.Text as Text -import Lens.Family2 (Lens', Traversal') +import Lens.Family2 (Lens', Traversal', (^.)) import Lens.Family2.Unchecked (lens) import TensorFlow.Output (Output, outputOp, opUnrendered, opAttr) -import TensorFlow.Types (TensorData(..), Attribute) +import TensorFlow.Types + ( TensorData(..) + , Attribute + , ListOf(..) + ) import qualified TensorFlow.Internal.FFI as FFI -- | A named output of a TensorFlow operation. @@ -83,3 +93,9 @@ feed (Tensor _ o) (TensorData td) = Feed o td -- TODO(judahjacobson): add more safety checks here. tensorFromName :: TensorKind v -> Text.Text -> Tensor v a tensorFromName v = Tensor v . fromString . Text.unpack + +type TensorList v = ListOf (Tensor v) + +tensorListOutputs :: TensorList v as -> [Output] +tensorListOutputs Nil = [] +tensorListOutputs (t :/ ts) = (t ^. tensorOutput) : tensorListOutputs ts diff --git a/tensorflow/src/TensorFlow/Types.hs b/tensorflow/src/TensorFlow/Types.hs index 3ed9cec..b5fd115 100644 --- a/tensorflow/src/TensorFlow/Types.hs +++ b/tensorflow/src/TensorFlow/Types.hs @@ -13,9 +13,11 @@ -- limitations under the License. {-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} @@ -36,23 +38,35 @@ module TensorFlow.Types , Shape(..) , protoShape , Attribute(..) + , DataType(..) + -- * Lists + , ListOf(..) + , List + , (/:/) + , TensorTypeProxy(..) + , TensorTypes(..) + , TensorTypeList + , fromTensorTypeList + , fromTensorTypes -- * Type constraints , OneOf , type (/=) + , OneOfs -- ** Implementation of constraints , TypeError , ExcludedCase - , TensorTypes , NoneOf , type (\\) , Delete , AllTensorTypes ) where +import Data.Functor.Identity (Identity(..)) import Data.Complex (Complex) import Data.Default (def) import Data.Int (Int8, Int16, Int32, Int64) import Data.Monoid ((<>)) +import Data.Proxy (Proxy(..)) import Data.String (IsString) import Data.Word (Word8, Word16, Word64) import Foreign.Storable (Storable) @@ -376,6 +390,44 @@ instance Attribute [DataType] where instance Attribute [Int64] where attrLens = list . i +-- | A heterogeneous list type. +data ListOf f as where + Nil :: ListOf f '[] + (:/) :: f a -> ListOf f as -> ListOf f (a ': as) + +infixr 5 :/ + +type family All f as :: Constraint where + All f '[] = () + All f (a ': as) = (f a, All f as) + +type family Map f as where + Map f '[] = '[] + Map f (a ': as) = f a ': Map f as + +instance All Eq (Map f as) => Eq (ListOf f as) where + Nil == Nil = True + (x :/ xs) == (y :/ ys) = x == y && xs == ys + -- Newer versions of GHC use the GADT to tell that the previous cases are + -- exhaustive. +#if _GLASGOW_HASKELL__ < 800 + _ == _ = False +#endif + +instance All Show (Map f as) => Show (ListOf f as) where + showsPrec _ Nil = showString "Nil" + showsPrec d (x :/ xs) = showParen (d > 10) + $ showsPrec 6 x . showString " :/ " + . showsPrec 6 xs + +type List = ListOf Identity + +-- | Equivalent of ':/' for lists. +(/:/) :: a -> List as -> List (a ': as) +(/:/) = (:/) . Identity + +infixr 5 /:/ + -- | A 'Constraint' specifying the possible choices of a 'TensorType'. -- -- We implement a 'Constraint' like @OneOf '[Double, Float] a@ by turning the @@ -393,13 +445,38 @@ instance Attribute [Int64] where -- -- using an enumeration of all the possible 'TensorType's. type OneOf ts a + -- Assert `TensorTypes ts` to make error messages a little better. = (TensorType a, TensorTypes ts, NoneOf (AllTensorTypes \\ ts) a) --- | A check that the input is a list of 'TensorType's. --- Helps improve error messages when using 'OneOf'. +type OneOfs ts as = (TensorTypes as, TensorTypes ts, + NoneOfs (AllTensorTypes \\ ts) as) + +type family NoneOfs ts as :: Constraint where + NoneOfs ts '[] = () + NoneOfs ts (a ': as) = (NoneOf ts a, NoneOfs ts as) + +data TensorTypeProxy a where + TensorTypeProxy :: TensorType a => TensorTypeProxy a + +type TensorTypeList = ListOf TensorTypeProxy + +fromTensorTypeList :: TensorTypeList ts -> [DataType] +fromTensorTypeList Nil = [] +fromTensorTypeList ((TensorTypeProxy :: TensorTypeProxy t) :/ ts) + = tensorType (undefined :: t) : fromTensorTypeList ts + +fromTensorTypes :: forall as . TensorTypes as => Proxy as -> [DataType] +fromTensorTypes _ = fromTensorTypeList (tensorTypes :: TensorTypeList as) + class TensorTypes (ts :: [*]) where -instance TensorTypes '[] -instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) + tensorTypes :: TensorTypeList ts + +instance TensorTypes '[] where + tensorTypes = Nil + +-- | A constraint that the input is a list of 'TensorTypes'. +instance (TensorType t, TensorTypes ts) => TensorTypes (t ': ts) where + tensorTypes = TensorTypeProxy :/ tensorTypes -- | A constraint checking that two types are different. type family a /= b :: Constraint where