From 8db944578a16526ed8645bf3903dbf7c465aa9e2 Mon Sep 17 00:00:00 2001 From: Greg Steuck Date: Tue, 8 Nov 2016 16:48:41 -0800 Subject: [PATCH] Support ResourceHandle. (#18) Exposed by moving to newer TF. --- tensorflow-opgen/src/TensorFlow/OpGen.hs | 60 ++++++++++++++++++------ tensorflow/src/TensorFlow/BuildOp.hs | 16 ++++++- tensorflow/src/TensorFlow/Output.hs | 6 +++ third_party/tensorflow | 2 +- 4 files changed, 67 insertions(+), 17 deletions(-) diff --git a/tensorflow-opgen/src/TensorFlow/OpGen.hs b/tensorflow-opgen/src/TensorFlow/OpGen.hs index bf0fc6e..446e049 100644 --- a/tensorflow-opgen/src/TensorFlow/OpGen.hs +++ b/tensorflow-opgen/src/TensorFlow/OpGen.hs @@ -25,10 +25,11 @@ module TensorFlow.OpGen import Prelude hiding (head, tail) +import Control.Applicative ((<**>)) import Control.Monad (guard) import Data.Char (toLower, toUpper) import Data.Foldable (toList) -import Data.Maybe (fromMaybe, maybeToList) +import Data.Maybe (catMaybes, fromMaybe, maybeToList) import Data.ProtoLens (def, showMessage) import Data.List.NonEmpty (NonEmpty((:|)), head) import qualified Data.List.NonEmpty as NE @@ -119,6 +120,10 @@ docOpList flags opList = , "{-# LANGUAGE OverloadedStrings #-}" , "{-# LANGUAGE RankNTypes #-}" , "{-# LANGUAGE ScopedTypeVariables #-}" + -- Avoids reports about shadowing standard library names. + , "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}" + -- eqLengthGuard never returns false and dies instead. + , "{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}" , "module" <+> strictText moduleName <+> "where" , empty , imports @@ -135,6 +140,7 @@ docOpList flags opList = shortName = Text.pack (takeBaseName $ outputFile flags) exclusions = Text.splitOn "," $ Text.pack $ excludeList flags +camelCase :: Text -> Text camelCase s = Text.concat $ map upCase $ filter (/= "ops") $ Text.splitOn "_" s @@ -152,6 +158,7 @@ forceCase :: (Char -> Char) -> Text -> Text forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs) (Text.uncons s) +imports :: Doc imports = stack [ "import Data.ByteString (ByteString)" , "import Data.Complex (Complex)" @@ -160,6 +167,7 @@ imports = stack [ , "import Lens.Family2 ((.~), (&))" , "import TensorFlow.Build" , "import TensorFlow.BuildOp" + , "import TensorFlow.Output (ResourceHandle)" , "import TensorFlow.Tensor" , "import TensorFlow.Types" ] @@ -237,40 +245,57 @@ quotedText = dquotes . strictText typeSig :: OpDef -> Doc typeSig d = foralls <+> constraints <+/> - signatureFold (mandatoryAttrInputs ++ tensorInputs ++ [outputs]) + signatureFold (mandatoryAttrInputs ++ map snd tensorInputs ++ [outputs]) where - foralls | Map.null typeMap = empty + foralls | null typeMap = empty | otherwise = "forall" - <+> sep (refTypes ++ map (strictText . fst) (Map.elems typeMap)) + <+> sep (refVariableNames ++ typeMapTypeNames) <+> "." - constraints | Map.null typeMap = empty + typeMapTypeNames = map (strictText . fst) (Map.elems typeMap) + constraints | null typeMap = empty | otherwise = tuple (concatMap (\(t, aDef) -> "TensorType" <+> strictText t : maybeToList (oneOfRestrictions aDef t)) (Map.elems typeMap)) <+> "=>" + refVariableNames = catMaybes (map fst tensorInputs) tensorInputs = zipWith tensorArg refTypes (d ^. inputArg) - refTypes = map (\x -> "v" <> int x) [1..length (d ^. inputArg)] - tensorArg refType arg = wrapArg refType arg <+> - hang 0 ("-- ^" <+> argComment arg) + refTypes = ["v" <> int x | x <- [1..length (d ^. inputArg)]] + tensorArg refType arg = wrapArg refType arg <**> + pure (<+> hang 0 ("-- ^" <+> argComment arg)) -- Argument type is a list of tensors if number_attr is set; -- otherwise it's a single Tensor. wrapArg refType arg = - if Text.null (arg ^. numberAttr) then typ else brackets typ + if Text.null (arg ^. numberAttr) then typ else brackets <$> typ where typ = tensorType refType arg - tensorType refType arg = - "Tensor" <+> refType <+> maybe directType strictText indirectType + -- The result is (reference type variable if any, type representing the arg) + tensorType :: Doc -> OpDef'ArgDef -> (Maybe Doc, Doc) + tensorType refType arg + -- Deals with resource handles that are unlike tensors and + -- have their own type level representation. The magic "dtype" + -- name is the name of the attribute specifying the type of + -- the resource handle. It is not referenced by resource input + -- or output because it isn't expressible in TF operation + -- signatures. + | (arg ^. type' == DT_RESOURCE) = + (Nothing, strictText "ResourceHandle dtype") + | otherwise = + (Just refType, + "Tensor" <+> refType <+> maybe directType strictText indirectType) where - indirectType = fmap fst (Map.lookup (arg ^. typeAttr) typeMap) + -- This is the case of a named parameter type which is + -- constrained as a OneOf. + indirectType = fst <$> (Map.lookup (arg ^. typeAttr) typeMap) + -- The nominal case when the type name is given directly. directType = dtTypeToDoc (arg ^. type') outputs = case d ^. outputArg of [] -> "ControlNode" [o] -> wrappedOutput o <+> "-- ^" <+> argComment o os -> renderTupleResult os - wrappedOutput = wrapArg "Value" + wrappedOutput = snd . wrapArg "Value" -- Tuple result case is rendered differently to give -- individual elements their own comments. renderTupleResult os = @@ -324,7 +349,11 @@ argComment' argName argDesc = bold :: Text.Text -> Doc bold n = strictText ("__" <> n <> "__") -opDefTypeMap :: OpDef -> Map.Map Text.Text (Text.Text, AttrDef) +type OpDefTypeMap = Map.Map Text.Text (Text.Text, AttrDef) + +-- | Returns the map of type parameters from OpDef type name to +-- (haskell friendly type name, the type's attribute definition). +opDefTypeMap :: OpDef -> OpDefTypeMap opDefTypeMap d = Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a] @@ -390,6 +419,8 @@ dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString" dtTypeToHaskell DT_UINT16 = "Data.Word.Word16" dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique dtTypeToHaskell DT_UINT8 = "Data.Word.Word8" +dtTypeToHaskell DT_RESOURCE = + error "ResourceHandle must be prevented from getting here." dtTypeToHaskell x = Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x @@ -419,6 +450,7 @@ replaceReservedName n | n `Set.member` reservedKeywords = n <> "'" | otherwise = n +indentation :: Int indentation = 4 reservedKeywords :: Set.Set Text diff --git a/tensorflow/src/TensorFlow/BuildOp.hs b/tensorflow/src/TensorFlow/BuildOp.hs index 9a96ced..3d6c675 100644 --- a/tensorflow/src/TensorFlow/BuildOp.hs +++ b/tensorflow/src/TensorFlow/BuildOp.hs @@ -77,11 +77,17 @@ instance ( OpResult a1 <*> toResult tensorResult :: TensorKind v -> Result (Tensor v a) -tensorResult v = do +tensorResult v = Tensor v <$> recordResult + +recordResult :: Result Output +recordResult = do o <- ask ResultState i ns <- get put $! ResultState (i+1) ns - return $! Tensor v $ output i o + return $! output i o + +instance OpResult (ResourceHandle a) where + toResult = ResourceHandle <$> recordResult instance OpResult (Tensor Value a) where toResult = tensorResult ValueKind @@ -144,6 +150,9 @@ buildListOp counts o = buildOp' counts o [] instance BuildOp ControlNode where buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts +instance BuildOp (ResourceHandle a) where + buildOp' = pureResult + instance BuildOp (Tensor Value a) where buildOp' = pureResult @@ -180,6 +189,9 @@ instance ( OpResult t1 instance OpResult a => BuildOp (Build a) where buildOp' = buildResult +instance BuildOp f => BuildOp (ResourceHandle a -> f) where + buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts) + instance BuildOp f => BuildOp (Tensor v a -> f) where buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts) diff --git a/tensorflow/src/TensorFlow/Output.hs b/tensorflow/src/TensorFlow/Output.hs index 6bee40a..b05eef9 100644 --- a/tensorflow/src/TensorFlow/Output.hs +++ b/tensorflow/src/TensorFlow/Output.hs @@ -36,6 +36,7 @@ module TensorFlow.Output , outputIndex , outputOp , PendingNodeName(..) + , ResourceHandle(..) ) where import qualified Data.Map.Strict as Map @@ -154,3 +155,8 @@ instance IsString Output where _ -> Output 0 $ assigned s where assigned n = Rendered $ def & name .~ Text.pack n + +-- | Opaque handle to a mutable resource in the graph. Typical such +-- resources are variables. The type parameter corresponds to the +-- dtype of the tensor held in the variable. +newtype ResourceHandle a = ResourceHandle Output diff --git a/third_party/tensorflow b/third_party/tensorflow index bac7faa..e1c7e51 160000 --- a/third_party/tensorflow +++ b/third_party/tensorflow @@ -1 +1 @@ -Subproject commit bac7faa9a3eb5b60687a83336202cd3493de5385 +Subproject commit e1c7e510a569cd5898f08015352bbdc8bef2ff7e