1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-23 11:29:43 +01:00

Support ResourceHandle. (#18)

Exposed by moving to newer TF.
This commit is contained in:
Greg Steuck 2016-11-08 16:48:41 -08:00 committed by Judah Jacobson
parent 29f11d351d
commit 8db944578a
4 changed files with 67 additions and 17 deletions

View file

@ -25,10 +25,11 @@ module TensorFlow.OpGen
import Prelude hiding (head, tail) import Prelude hiding (head, tail)
import Control.Applicative ((<**>))
import Control.Monad (guard) import Control.Monad (guard)
import Data.Char (toLower, toUpper) import Data.Char (toLower, toUpper)
import Data.Foldable (toList) import Data.Foldable (toList)
import Data.Maybe (fromMaybe, maybeToList) import Data.Maybe (catMaybes, fromMaybe, maybeToList)
import Data.ProtoLens (def, showMessage) import Data.ProtoLens (def, showMessage)
import Data.List.NonEmpty (NonEmpty((:|)), head) import Data.List.NonEmpty (NonEmpty((:|)), head)
import qualified Data.List.NonEmpty as NE import qualified Data.List.NonEmpty as NE
@ -119,6 +120,10 @@ docOpList flags opList =
, "{-# LANGUAGE OverloadedStrings #-}" , "{-# LANGUAGE OverloadedStrings #-}"
, "{-# LANGUAGE RankNTypes #-}" , "{-# LANGUAGE RankNTypes #-}"
, "{-# LANGUAGE ScopedTypeVariables #-}" , "{-# LANGUAGE ScopedTypeVariables #-}"
-- Avoids reports about shadowing standard library names.
, "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}"
-- eqLengthGuard never returns false and dies instead.
, "{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}"
, "module" <+> strictText moduleName <+> "where" , "module" <+> strictText moduleName <+> "where"
, empty , empty
, imports , imports
@ -135,6 +140,7 @@ docOpList flags opList =
shortName = Text.pack (takeBaseName $ outputFile flags) shortName = Text.pack (takeBaseName $ outputFile flags)
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
camelCase :: Text -> Text
camelCase s = Text.concat $ map upCase camelCase s = Text.concat $ map upCase
$ filter (/= "ops") $ filter (/= "ops")
$ Text.splitOn "_" s $ Text.splitOn "_" s
@ -152,6 +158,7 @@ forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs) forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
(Text.uncons s) (Text.uncons s)
imports :: Doc
imports = stack [ imports = stack [
"import Data.ByteString (ByteString)" "import Data.ByteString (ByteString)"
, "import Data.Complex (Complex)" , "import Data.Complex (Complex)"
@ -160,6 +167,7 @@ imports = stack [
, "import Lens.Family2 ((.~), (&))" , "import Lens.Family2 ((.~), (&))"
, "import TensorFlow.Build" , "import TensorFlow.Build"
, "import TensorFlow.BuildOp" , "import TensorFlow.BuildOp"
, "import TensorFlow.Output (ResourceHandle)"
, "import TensorFlow.Tensor" , "import TensorFlow.Tensor"
, "import TensorFlow.Types" , "import TensorFlow.Types"
] ]
@ -237,40 +245,57 @@ quotedText = dquotes . strictText
typeSig :: OpDef -> Doc typeSig :: OpDef -> Doc
typeSig d = typeSig d =
foralls <+> constraints <+/> foralls <+> constraints <+/>
signatureFold (mandatoryAttrInputs ++ tensorInputs ++ [outputs]) signatureFold (mandatoryAttrInputs ++ map snd tensorInputs ++ [outputs])
where where
foralls | Map.null typeMap = empty foralls | null typeMap = empty
| otherwise = | otherwise =
"forall" "forall"
<+> sep (refTypes ++ map (strictText . fst) (Map.elems typeMap)) <+> sep (refVariableNames ++ typeMapTypeNames)
<+> "." <+> "."
constraints | Map.null typeMap = empty typeMapTypeNames = map (strictText . fst) (Map.elems typeMap)
constraints | null typeMap = empty
| otherwise = | otherwise =
tuple (concatMap tuple (concatMap
(\(t, aDef) -> (\(t, aDef) ->
"TensorType" <+> strictText t "TensorType" <+> strictText t
: maybeToList (oneOfRestrictions aDef t)) : maybeToList (oneOfRestrictions aDef t))
(Map.elems typeMap)) <+> "=>" (Map.elems typeMap)) <+> "=>"
refVariableNames = catMaybes (map fst tensorInputs)
tensorInputs = zipWith tensorArg refTypes (d ^. inputArg) tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
refTypes = map (\x -> "v" <> int x) [1..length (d ^. inputArg)] refTypes = ["v" <> int x | x <- [1..length (d ^. inputArg)]]
tensorArg refType arg = wrapArg refType arg <+> tensorArg refType arg = wrapArg refType arg <**>
hang 0 ("-- ^" <+> argComment arg) pure (<+> hang 0 ("-- ^" <+> argComment arg))
-- Argument type is a list of tensors if number_attr is set; -- Argument type is a list of tensors if number_attr is set;
-- otherwise it's a single Tensor. -- otherwise it's a single Tensor.
wrapArg refType arg = wrapArg refType arg =
if Text.null (arg ^. numberAttr) then typ else brackets typ if Text.null (arg ^. numberAttr) then typ else brackets <$> typ
where typ = tensorType refType arg where typ = tensorType refType arg
tensorType refType arg = -- The result is (reference type variable if any, type representing the arg)
"Tensor" <+> refType <+> maybe directType strictText indirectType tensorType :: Doc -> OpDef'ArgDef -> (Maybe Doc, Doc)
tensorType refType arg
-- Deals with resource handles that are unlike tensors and
-- have their own type level representation. The magic "dtype"
-- name is the name of the attribute specifying the type of
-- the resource handle. It is not referenced by resource input
-- or output because it isn't expressible in TF operation
-- signatures.
| (arg ^. type' == DT_RESOURCE) =
(Nothing, strictText "ResourceHandle dtype")
| otherwise =
(Just refType,
"Tensor" <+> refType <+> maybe directType strictText indirectType)
where where
indirectType = fmap fst (Map.lookup (arg ^. typeAttr) typeMap) -- This is the case of a named parameter type which is
-- constrained as a OneOf.
indirectType = fst <$> (Map.lookup (arg ^. typeAttr) typeMap)
-- The nominal case when the type name is given directly.
directType = dtTypeToDoc (arg ^. type') directType = dtTypeToDoc (arg ^. type')
outputs = outputs =
case d ^. outputArg of case d ^. outputArg of
[] -> "ControlNode" [] -> "ControlNode"
[o] -> wrappedOutput o <+> "-- ^" <+> argComment o [o] -> wrappedOutput o <+> "-- ^" <+> argComment o
os -> renderTupleResult os os -> renderTupleResult os
wrappedOutput = wrapArg "Value" wrappedOutput = snd . wrapArg "Value"
-- Tuple result case is rendered differently to give -- Tuple result case is rendered differently to give
-- individual elements their own comments. -- individual elements their own comments.
renderTupleResult os = renderTupleResult os =
@ -324,7 +349,11 @@ argComment' argName argDesc =
bold :: Text.Text -> Doc bold :: Text.Text -> Doc
bold n = strictText ("__" <> n <> "__") bold n = strictText ("__" <> n <> "__")
opDefTypeMap :: OpDef -> Map.Map Text.Text (Text.Text, AttrDef) type OpDefTypeMap = Map.Map Text.Text (Text.Text, AttrDef)
-- | Returns the map of type parameters from OpDef type name to
-- (haskell friendly type name, the type's attribute definition).
opDefTypeMap :: OpDef -> OpDefTypeMap
opDefTypeMap d = opDefTypeMap d =
Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a] Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a]
@ -390,6 +419,8 @@ dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16" dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8" dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
dtTypeToHaskell DT_RESOURCE =
error "ResourceHandle must be prevented from getting here."
dtTypeToHaskell x = dtTypeToHaskell x =
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
@ -419,6 +450,7 @@ replaceReservedName n
| n `Set.member` reservedKeywords = n <> "'" | n `Set.member` reservedKeywords = n <> "'"
| otherwise = n | otherwise = n
indentation :: Int
indentation = 4 indentation = 4
reservedKeywords :: Set.Set Text reservedKeywords :: Set.Set Text

View file

@ -77,11 +77,17 @@ instance ( OpResult a1
<*> toResult <*> toResult
tensorResult :: TensorKind v -> Result (Tensor v a) tensorResult :: TensorKind v -> Result (Tensor v a)
tensorResult v = do tensorResult v = Tensor v <$> recordResult
recordResult :: Result Output
recordResult = do
o <- ask o <- ask
ResultState i ns <- get ResultState i ns <- get
put $! ResultState (i+1) ns put $! ResultState (i+1) ns
return $! Tensor v $ output i o return $! output i o
instance OpResult (ResourceHandle a) where
toResult = ResourceHandle <$> recordResult
instance OpResult (Tensor Value a) where instance OpResult (Tensor Value a) where
toResult = tensorResult ValueKind toResult = tensorResult ValueKind
@ -144,6 +150,9 @@ buildListOp counts o = buildOp' counts o []
instance BuildOp ControlNode where instance BuildOp ControlNode where
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
instance BuildOp (ResourceHandle a) where
buildOp' = pureResult
instance BuildOp (Tensor Value a) where instance BuildOp (Tensor Value a) where
buildOp' = pureResult buildOp' = pureResult
@ -180,6 +189,9 @@ instance ( OpResult t1
instance OpResult a => BuildOp (Build a) where instance OpResult a => BuildOp (Build a) where
buildOp' = buildResult buildOp' = buildResult
instance BuildOp f => BuildOp (ResourceHandle a -> f) where
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
instance BuildOp f => BuildOp (Tensor v a -> f) where instance BuildOp f => BuildOp (Tensor v a -> f) where
buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts) buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts)

View file

@ -36,6 +36,7 @@ module TensorFlow.Output
, outputIndex , outputIndex
, outputOp , outputOp
, PendingNodeName(..) , PendingNodeName(..)
, ResourceHandle(..)
) where ) where
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
@ -154,3 +155,8 @@ instance IsString Output where
_ -> Output 0 $ assigned s _ -> Output 0 $ assigned s
where assigned n = Rendered $ def & name .~ Text.pack n where assigned n = Rendered $ def & name .~ Text.pack n
-- | Opaque handle to a mutable resource in the graph. Typical such
-- resources are variables. The type parameter corresponds to the
-- dtype of the tensor held in the variable.
newtype ResourceHandle a = ResourceHandle Output

@ -1 +1 @@
Subproject commit bac7faa9a3eb5b60687a83336202cd3493de5385 Subproject commit e1c7e510a569cd5898f08015352bbdc8bef2ff7e