1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-24 18:49:46 +01:00

Support ResourceHandle. (#18)

Exposed by moving to newer TF.
This commit is contained in:
Greg Steuck 2016-11-08 16:48:41 -08:00 committed by Judah Jacobson
parent 29f11d351d
commit 8db944578a
4 changed files with 67 additions and 17 deletions

View file

@ -25,10 +25,11 @@ module TensorFlow.OpGen
import Prelude hiding (head, tail)
import Control.Applicative ((<**>))
import Control.Monad (guard)
import Data.Char (toLower, toUpper)
import Data.Foldable (toList)
import Data.Maybe (fromMaybe, maybeToList)
import Data.Maybe (catMaybes, fromMaybe, maybeToList)
import Data.ProtoLens (def, showMessage)
import Data.List.NonEmpty (NonEmpty((:|)), head)
import qualified Data.List.NonEmpty as NE
@ -119,6 +120,10 @@ docOpList flags opList =
, "{-# LANGUAGE OverloadedStrings #-}"
, "{-# LANGUAGE RankNTypes #-}"
, "{-# LANGUAGE ScopedTypeVariables #-}"
-- Avoids reports about shadowing standard library names.
, "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}"
-- eqLengthGuard never returns false and dies instead.
, "{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}"
, "module" <+> strictText moduleName <+> "where"
, empty
, imports
@ -135,6 +140,7 @@ docOpList flags opList =
shortName = Text.pack (takeBaseName $ outputFile flags)
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
camelCase :: Text -> Text
camelCase s = Text.concat $ map upCase
$ filter (/= "ops")
$ Text.splitOn "_" s
@ -152,6 +158,7 @@ forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
(Text.uncons s)
imports :: Doc
imports = stack [
"import Data.ByteString (ByteString)"
, "import Data.Complex (Complex)"
@ -160,6 +167,7 @@ imports = stack [
, "import Lens.Family2 ((.~), (&))"
, "import TensorFlow.Build"
, "import TensorFlow.BuildOp"
, "import TensorFlow.Output (ResourceHandle)"
, "import TensorFlow.Tensor"
, "import TensorFlow.Types"
]
@ -237,40 +245,57 @@ quotedText = dquotes . strictText
typeSig :: OpDef -> Doc
typeSig d =
foralls <+> constraints <+/>
signatureFold (mandatoryAttrInputs ++ tensorInputs ++ [outputs])
signatureFold (mandatoryAttrInputs ++ map snd tensorInputs ++ [outputs])
where
foralls | Map.null typeMap = empty
foralls | null typeMap = empty
| otherwise =
"forall"
<+> sep (refTypes ++ map (strictText . fst) (Map.elems typeMap))
<+> sep (refVariableNames ++ typeMapTypeNames)
<+> "."
constraints | Map.null typeMap = empty
typeMapTypeNames = map (strictText . fst) (Map.elems typeMap)
constraints | null typeMap = empty
| otherwise =
tuple (concatMap
(\(t, aDef) ->
"TensorType" <+> strictText t
: maybeToList (oneOfRestrictions aDef t))
(Map.elems typeMap)) <+> "=>"
refVariableNames = catMaybes (map fst tensorInputs)
tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
refTypes = map (\x -> "v" <> int x) [1..length (d ^. inputArg)]
tensorArg refType arg = wrapArg refType arg <+>
hang 0 ("-- ^" <+> argComment arg)
refTypes = ["v" <> int x | x <- [1..length (d ^. inputArg)]]
tensorArg refType arg = wrapArg refType arg <**>
pure (<+> hang 0 ("-- ^" <+> argComment arg))
-- Argument type is a list of tensors if number_attr is set;
-- otherwise it's a single Tensor.
wrapArg refType arg =
if Text.null (arg ^. numberAttr) then typ else brackets typ
if Text.null (arg ^. numberAttr) then typ else brackets <$> typ
where typ = tensorType refType arg
tensorType refType arg =
"Tensor" <+> refType <+> maybe directType strictText indirectType
-- The result is (reference type variable if any, type representing the arg)
tensorType :: Doc -> OpDef'ArgDef -> (Maybe Doc, Doc)
tensorType refType arg
-- Deals with resource handles that are unlike tensors and
-- have their own type level representation. The magic "dtype"
-- name is the name of the attribute specifying the type of
-- the resource handle. It is not referenced by resource input
-- or output because it isn't expressible in TF operation
-- signatures.
| (arg ^. type' == DT_RESOURCE) =
(Nothing, strictText "ResourceHandle dtype")
| otherwise =
(Just refType,
"Tensor" <+> refType <+> maybe directType strictText indirectType)
where
indirectType = fmap fst (Map.lookup (arg ^. typeAttr) typeMap)
-- This is the case of a named parameter type which is
-- constrained as a OneOf.
indirectType = fst <$> (Map.lookup (arg ^. typeAttr) typeMap)
-- The nominal case when the type name is given directly.
directType = dtTypeToDoc (arg ^. type')
outputs =
case d ^. outputArg of
[] -> "ControlNode"
[o] -> wrappedOutput o <+> "-- ^" <+> argComment o
os -> renderTupleResult os
wrappedOutput = wrapArg "Value"
wrappedOutput = snd . wrapArg "Value"
-- Tuple result case is rendered differently to give
-- individual elements their own comments.
renderTupleResult os =
@ -324,7 +349,11 @@ argComment' argName argDesc =
bold :: Text.Text -> Doc
bold n = strictText ("__" <> n <> "__")
opDefTypeMap :: OpDef -> Map.Map Text.Text (Text.Text, AttrDef)
type OpDefTypeMap = Map.Map Text.Text (Text.Text, AttrDef)
-- | Returns the map of type parameters from OpDef type name to
-- (haskell friendly type name, the type's attribute definition).
opDefTypeMap :: OpDef -> OpDefTypeMap
opDefTypeMap d =
Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a]
@ -390,6 +419,8 @@ dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
dtTypeToHaskell DT_RESOURCE =
error "ResourceHandle must be prevented from getting here."
dtTypeToHaskell x =
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
@ -419,6 +450,7 @@ replaceReservedName n
| n `Set.member` reservedKeywords = n <> "'"
| otherwise = n
indentation :: Int
indentation = 4
reservedKeywords :: Set.Set Text

View file

@ -77,11 +77,17 @@ instance ( OpResult a1
<*> toResult
tensorResult :: TensorKind v -> Result (Tensor v a)
tensorResult v = do
tensorResult v = Tensor v <$> recordResult
recordResult :: Result Output
recordResult = do
o <- ask
ResultState i ns <- get
put $! ResultState (i+1) ns
return $! Tensor v $ output i o
return $! output i o
instance OpResult (ResourceHandle a) where
toResult = ResourceHandle <$> recordResult
instance OpResult (Tensor Value a) where
toResult = tensorResult ValueKind
@ -144,6 +150,9 @@ buildListOp counts o = buildOp' counts o []
instance BuildOp ControlNode where
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
instance BuildOp (ResourceHandle a) where
buildOp' = pureResult
instance BuildOp (Tensor Value a) where
buildOp' = pureResult
@ -180,6 +189,9 @@ instance ( OpResult t1
instance OpResult a => BuildOp (Build a) where
buildOp' = buildResult
instance BuildOp f => BuildOp (ResourceHandle a -> f) where
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
instance BuildOp f => BuildOp (Tensor v a -> f) where
buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts)

View file

@ -36,6 +36,7 @@ module TensorFlow.Output
, outputIndex
, outputOp
, PendingNodeName(..)
, ResourceHandle(..)
) where
import qualified Data.Map.Strict as Map
@ -154,3 +155,8 @@ instance IsString Output where
_ -> Output 0 $ assigned s
where assigned n = Rendered $ def & name .~ Text.pack n
-- | Opaque handle to a mutable resource in the graph. Typical such
-- resources are variables. The type parameter corresponds to the
-- dtype of the tensor held in the variable.
newtype ResourceHandle a = ResourceHandle Output

@ -1 +1 @@
Subproject commit bac7faa9a3eb5b60687a83336202cd3493de5385
Subproject commit e1c7e510a569cd5898f08015352bbdc8bef2ff7e