mirror of
https://github.com/tensorflow/haskell.git
synced 2024-12-24 18:49:46 +01:00
parent
29f11d351d
commit
8db944578a
4 changed files with 67 additions and 17 deletions
|
@ -25,10 +25,11 @@ module TensorFlow.OpGen
|
|||
|
||||
import Prelude hiding (head, tail)
|
||||
|
||||
import Control.Applicative ((<**>))
|
||||
import Control.Monad (guard)
|
||||
import Data.Char (toLower, toUpper)
|
||||
import Data.Foldable (toList)
|
||||
import Data.Maybe (fromMaybe, maybeToList)
|
||||
import Data.Maybe (catMaybes, fromMaybe, maybeToList)
|
||||
import Data.ProtoLens (def, showMessage)
|
||||
import Data.List.NonEmpty (NonEmpty((:|)), head)
|
||||
import qualified Data.List.NonEmpty as NE
|
||||
|
@ -119,6 +120,10 @@ docOpList flags opList =
|
|||
, "{-# LANGUAGE OverloadedStrings #-}"
|
||||
, "{-# LANGUAGE RankNTypes #-}"
|
||||
, "{-# LANGUAGE ScopedTypeVariables #-}"
|
||||
-- Avoids reports about shadowing standard library names.
|
||||
, "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}"
|
||||
-- eqLengthGuard never returns false and dies instead.
|
||||
, "{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}"
|
||||
, "module" <+> strictText moduleName <+> "where"
|
||||
, empty
|
||||
, imports
|
||||
|
@ -135,6 +140,7 @@ docOpList flags opList =
|
|||
shortName = Text.pack (takeBaseName $ outputFile flags)
|
||||
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
|
||||
|
||||
camelCase :: Text -> Text
|
||||
camelCase s = Text.concat $ map upCase
|
||||
$ filter (/= "ops")
|
||||
$ Text.splitOn "_" s
|
||||
|
@ -152,6 +158,7 @@ forceCase :: (Char -> Char) -> Text -> Text
|
|||
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
|
||||
(Text.uncons s)
|
||||
|
||||
imports :: Doc
|
||||
imports = stack [
|
||||
"import Data.ByteString (ByteString)"
|
||||
, "import Data.Complex (Complex)"
|
||||
|
@ -160,6 +167,7 @@ imports = stack [
|
|||
, "import Lens.Family2 ((.~), (&))"
|
||||
, "import TensorFlow.Build"
|
||||
, "import TensorFlow.BuildOp"
|
||||
, "import TensorFlow.Output (ResourceHandle)"
|
||||
, "import TensorFlow.Tensor"
|
||||
, "import TensorFlow.Types"
|
||||
]
|
||||
|
@ -237,40 +245,57 @@ quotedText = dquotes . strictText
|
|||
typeSig :: OpDef -> Doc
|
||||
typeSig d =
|
||||
foralls <+> constraints <+/>
|
||||
signatureFold (mandatoryAttrInputs ++ tensorInputs ++ [outputs])
|
||||
signatureFold (mandatoryAttrInputs ++ map snd tensorInputs ++ [outputs])
|
||||
where
|
||||
foralls | Map.null typeMap = empty
|
||||
foralls | null typeMap = empty
|
||||
| otherwise =
|
||||
"forall"
|
||||
<+> sep (refTypes ++ map (strictText . fst) (Map.elems typeMap))
|
||||
<+> sep (refVariableNames ++ typeMapTypeNames)
|
||||
<+> "."
|
||||
constraints | Map.null typeMap = empty
|
||||
typeMapTypeNames = map (strictText . fst) (Map.elems typeMap)
|
||||
constraints | null typeMap = empty
|
||||
| otherwise =
|
||||
tuple (concatMap
|
||||
(\(t, aDef) ->
|
||||
"TensorType" <+> strictText t
|
||||
: maybeToList (oneOfRestrictions aDef t))
|
||||
(Map.elems typeMap)) <+> "=>"
|
||||
refVariableNames = catMaybes (map fst tensorInputs)
|
||||
tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
|
||||
refTypes = map (\x -> "v" <> int x) [1..length (d ^. inputArg)]
|
||||
tensorArg refType arg = wrapArg refType arg <+>
|
||||
hang 0 ("-- ^" <+> argComment arg)
|
||||
refTypes = ["v" <> int x | x <- [1..length (d ^. inputArg)]]
|
||||
tensorArg refType arg = wrapArg refType arg <**>
|
||||
pure (<+> hang 0 ("-- ^" <+> argComment arg))
|
||||
-- Argument type is a list of tensors if number_attr is set;
|
||||
-- otherwise it's a single Tensor.
|
||||
wrapArg refType arg =
|
||||
if Text.null (arg ^. numberAttr) then typ else brackets typ
|
||||
if Text.null (arg ^. numberAttr) then typ else brackets <$> typ
|
||||
where typ = tensorType refType arg
|
||||
tensorType refType arg =
|
||||
"Tensor" <+> refType <+> maybe directType strictText indirectType
|
||||
-- The result is (reference type variable if any, type representing the arg)
|
||||
tensorType :: Doc -> OpDef'ArgDef -> (Maybe Doc, Doc)
|
||||
tensorType refType arg
|
||||
-- Deals with resource handles that are unlike tensors and
|
||||
-- have their own type level representation. The magic "dtype"
|
||||
-- name is the name of the attribute specifying the type of
|
||||
-- the resource handle. It is not referenced by resource input
|
||||
-- or output because it isn't expressible in TF operation
|
||||
-- signatures.
|
||||
| (arg ^. type' == DT_RESOURCE) =
|
||||
(Nothing, strictText "ResourceHandle dtype")
|
||||
| otherwise =
|
||||
(Just refType,
|
||||
"Tensor" <+> refType <+> maybe directType strictText indirectType)
|
||||
where
|
||||
indirectType = fmap fst (Map.lookup (arg ^. typeAttr) typeMap)
|
||||
-- This is the case of a named parameter type which is
|
||||
-- constrained as a OneOf.
|
||||
indirectType = fst <$> (Map.lookup (arg ^. typeAttr) typeMap)
|
||||
-- The nominal case when the type name is given directly.
|
||||
directType = dtTypeToDoc (arg ^. type')
|
||||
outputs =
|
||||
case d ^. outputArg of
|
||||
[] -> "ControlNode"
|
||||
[o] -> wrappedOutput o <+> "-- ^" <+> argComment o
|
||||
os -> renderTupleResult os
|
||||
wrappedOutput = wrapArg "Value"
|
||||
wrappedOutput = snd . wrapArg "Value"
|
||||
-- Tuple result case is rendered differently to give
|
||||
-- individual elements their own comments.
|
||||
renderTupleResult os =
|
||||
|
@ -324,7 +349,11 @@ argComment' argName argDesc =
|
|||
bold :: Text.Text -> Doc
|
||||
bold n = strictText ("__" <> n <> "__")
|
||||
|
||||
opDefTypeMap :: OpDef -> Map.Map Text.Text (Text.Text, AttrDef)
|
||||
type OpDefTypeMap = Map.Map Text.Text (Text.Text, AttrDef)
|
||||
|
||||
-- | Returns the map of type parameters from OpDef type name to
|
||||
-- (haskell friendly type name, the type's attribute definition).
|
||||
opDefTypeMap :: OpDef -> OpDefTypeMap
|
||||
opDefTypeMap d =
|
||||
Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a]
|
||||
|
||||
|
@ -390,6 +419,8 @@ dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
|
|||
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
|
||||
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
||||
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
|
||||
dtTypeToHaskell DT_RESOURCE =
|
||||
error "ResourceHandle must be prevented from getting here."
|
||||
dtTypeToHaskell x =
|
||||
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
|
||||
|
||||
|
@ -419,6 +450,7 @@ replaceReservedName n
|
|||
| n `Set.member` reservedKeywords = n <> "'"
|
||||
| otherwise = n
|
||||
|
||||
indentation :: Int
|
||||
indentation = 4
|
||||
|
||||
reservedKeywords :: Set.Set Text
|
||||
|
|
|
@ -77,11 +77,17 @@ instance ( OpResult a1
|
|||
<*> toResult
|
||||
|
||||
tensorResult :: TensorKind v -> Result (Tensor v a)
|
||||
tensorResult v = do
|
||||
tensorResult v = Tensor v <$> recordResult
|
||||
|
||||
recordResult :: Result Output
|
||||
recordResult = do
|
||||
o <- ask
|
||||
ResultState i ns <- get
|
||||
put $! ResultState (i+1) ns
|
||||
return $! Tensor v $ output i o
|
||||
return $! output i o
|
||||
|
||||
instance OpResult (ResourceHandle a) where
|
||||
toResult = ResourceHandle <$> recordResult
|
||||
|
||||
instance OpResult (Tensor Value a) where
|
||||
toResult = tensorResult ValueKind
|
||||
|
@ -144,6 +150,9 @@ buildListOp counts o = buildOp' counts o []
|
|||
instance BuildOp ControlNode where
|
||||
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
|
||||
|
||||
instance BuildOp (ResourceHandle a) where
|
||||
buildOp' = pureResult
|
||||
|
||||
instance BuildOp (Tensor Value a) where
|
||||
buildOp' = pureResult
|
||||
|
||||
|
@ -180,6 +189,9 @@ instance ( OpResult t1
|
|||
instance OpResult a => BuildOp (Build a) where
|
||||
buildOp' = buildResult
|
||||
|
||||
instance BuildOp f => BuildOp (ResourceHandle a -> f) where
|
||||
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
|
||||
|
||||
instance BuildOp f => BuildOp (Tensor v a -> f) where
|
||||
buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts)
|
||||
|
||||
|
|
|
@ -36,6 +36,7 @@ module TensorFlow.Output
|
|||
, outputIndex
|
||||
, outputOp
|
||||
, PendingNodeName(..)
|
||||
, ResourceHandle(..)
|
||||
) where
|
||||
|
||||
import qualified Data.Map.Strict as Map
|
||||
|
@ -154,3 +155,8 @@ instance IsString Output where
|
|||
_ -> Output 0 $ assigned s
|
||||
where assigned n = Rendered $ def & name .~ Text.pack n
|
||||
|
||||
|
||||
-- | Opaque handle to a mutable resource in the graph. Typical such
|
||||
-- resources are variables. The type parameter corresponds to the
|
||||
-- dtype of the tensor held in the variable.
|
||||
newtype ResourceHandle a = ResourceHandle Output
|
||||
|
|
2
third_party/tensorflow
vendored
2
third_party/tensorflow
vendored
|
@ -1 +1 @@
|
|||
Subproject commit bac7faa9a3eb5b60687a83336202cd3493de5385
|
||||
Subproject commit e1c7e510a569cd5898f08015352bbdc8bef2ff7e
|
Loading…
Reference in a new issue