mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-26 21:09:44 +01:00
parent
29f11d351d
commit
8db944578a
4 changed files with 67 additions and 17 deletions
|
@ -25,10 +25,11 @@ module TensorFlow.OpGen
|
||||||
|
|
||||||
import Prelude hiding (head, tail)
|
import Prelude hiding (head, tail)
|
||||||
|
|
||||||
|
import Control.Applicative ((<**>))
|
||||||
import Control.Monad (guard)
|
import Control.Monad (guard)
|
||||||
import Data.Char (toLower, toUpper)
|
import Data.Char (toLower, toUpper)
|
||||||
import Data.Foldable (toList)
|
import Data.Foldable (toList)
|
||||||
import Data.Maybe (fromMaybe, maybeToList)
|
import Data.Maybe (catMaybes, fromMaybe, maybeToList)
|
||||||
import Data.ProtoLens (def, showMessage)
|
import Data.ProtoLens (def, showMessage)
|
||||||
import Data.List.NonEmpty (NonEmpty((:|)), head)
|
import Data.List.NonEmpty (NonEmpty((:|)), head)
|
||||||
import qualified Data.List.NonEmpty as NE
|
import qualified Data.List.NonEmpty as NE
|
||||||
|
@ -119,6 +120,10 @@ docOpList flags opList =
|
||||||
, "{-# LANGUAGE OverloadedStrings #-}"
|
, "{-# LANGUAGE OverloadedStrings #-}"
|
||||||
, "{-# LANGUAGE RankNTypes #-}"
|
, "{-# LANGUAGE RankNTypes #-}"
|
||||||
, "{-# LANGUAGE ScopedTypeVariables #-}"
|
, "{-# LANGUAGE ScopedTypeVariables #-}"
|
||||||
|
-- Avoids reports about shadowing standard library names.
|
||||||
|
, "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}"
|
||||||
|
-- eqLengthGuard never returns false and dies instead.
|
||||||
|
, "{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}"
|
||||||
, "module" <+> strictText moduleName <+> "where"
|
, "module" <+> strictText moduleName <+> "where"
|
||||||
, empty
|
, empty
|
||||||
, imports
|
, imports
|
||||||
|
@ -135,6 +140,7 @@ docOpList flags opList =
|
||||||
shortName = Text.pack (takeBaseName $ outputFile flags)
|
shortName = Text.pack (takeBaseName $ outputFile flags)
|
||||||
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
|
exclusions = Text.splitOn "," $ Text.pack $ excludeList flags
|
||||||
|
|
||||||
|
camelCase :: Text -> Text
|
||||||
camelCase s = Text.concat $ map upCase
|
camelCase s = Text.concat $ map upCase
|
||||||
$ filter (/= "ops")
|
$ filter (/= "ops")
|
||||||
$ Text.splitOn "_" s
|
$ Text.splitOn "_" s
|
||||||
|
@ -152,6 +158,7 @@ forceCase :: (Char -> Char) -> Text -> Text
|
||||||
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
|
forceCase convert s = maybe "" (\(c, cs) -> Text.cons (convert c) cs)
|
||||||
(Text.uncons s)
|
(Text.uncons s)
|
||||||
|
|
||||||
|
imports :: Doc
|
||||||
imports = stack [
|
imports = stack [
|
||||||
"import Data.ByteString (ByteString)"
|
"import Data.ByteString (ByteString)"
|
||||||
, "import Data.Complex (Complex)"
|
, "import Data.Complex (Complex)"
|
||||||
|
@ -160,6 +167,7 @@ imports = stack [
|
||||||
, "import Lens.Family2 ((.~), (&))"
|
, "import Lens.Family2 ((.~), (&))"
|
||||||
, "import TensorFlow.Build"
|
, "import TensorFlow.Build"
|
||||||
, "import TensorFlow.BuildOp"
|
, "import TensorFlow.BuildOp"
|
||||||
|
, "import TensorFlow.Output (ResourceHandle)"
|
||||||
, "import TensorFlow.Tensor"
|
, "import TensorFlow.Tensor"
|
||||||
, "import TensorFlow.Types"
|
, "import TensorFlow.Types"
|
||||||
]
|
]
|
||||||
|
@ -237,40 +245,57 @@ quotedText = dquotes . strictText
|
||||||
typeSig :: OpDef -> Doc
|
typeSig :: OpDef -> Doc
|
||||||
typeSig d =
|
typeSig d =
|
||||||
foralls <+> constraints <+/>
|
foralls <+> constraints <+/>
|
||||||
signatureFold (mandatoryAttrInputs ++ tensorInputs ++ [outputs])
|
signatureFold (mandatoryAttrInputs ++ map snd tensorInputs ++ [outputs])
|
||||||
where
|
where
|
||||||
foralls | Map.null typeMap = empty
|
foralls | null typeMap = empty
|
||||||
| otherwise =
|
| otherwise =
|
||||||
"forall"
|
"forall"
|
||||||
<+> sep (refTypes ++ map (strictText . fst) (Map.elems typeMap))
|
<+> sep (refVariableNames ++ typeMapTypeNames)
|
||||||
<+> "."
|
<+> "."
|
||||||
constraints | Map.null typeMap = empty
|
typeMapTypeNames = map (strictText . fst) (Map.elems typeMap)
|
||||||
|
constraints | null typeMap = empty
|
||||||
| otherwise =
|
| otherwise =
|
||||||
tuple (concatMap
|
tuple (concatMap
|
||||||
(\(t, aDef) ->
|
(\(t, aDef) ->
|
||||||
"TensorType" <+> strictText t
|
"TensorType" <+> strictText t
|
||||||
: maybeToList (oneOfRestrictions aDef t))
|
: maybeToList (oneOfRestrictions aDef t))
|
||||||
(Map.elems typeMap)) <+> "=>"
|
(Map.elems typeMap)) <+> "=>"
|
||||||
|
refVariableNames = catMaybes (map fst tensorInputs)
|
||||||
tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
|
tensorInputs = zipWith tensorArg refTypes (d ^. inputArg)
|
||||||
refTypes = map (\x -> "v" <> int x) [1..length (d ^. inputArg)]
|
refTypes = ["v" <> int x | x <- [1..length (d ^. inputArg)]]
|
||||||
tensorArg refType arg = wrapArg refType arg <+>
|
tensorArg refType arg = wrapArg refType arg <**>
|
||||||
hang 0 ("-- ^" <+> argComment arg)
|
pure (<+> hang 0 ("-- ^" <+> argComment arg))
|
||||||
-- Argument type is a list of tensors if number_attr is set;
|
-- Argument type is a list of tensors if number_attr is set;
|
||||||
-- otherwise it's a single Tensor.
|
-- otherwise it's a single Tensor.
|
||||||
wrapArg refType arg =
|
wrapArg refType arg =
|
||||||
if Text.null (arg ^. numberAttr) then typ else brackets typ
|
if Text.null (arg ^. numberAttr) then typ else brackets <$> typ
|
||||||
where typ = tensorType refType arg
|
where typ = tensorType refType arg
|
||||||
tensorType refType arg =
|
-- The result is (reference type variable if any, type representing the arg)
|
||||||
"Tensor" <+> refType <+> maybe directType strictText indirectType
|
tensorType :: Doc -> OpDef'ArgDef -> (Maybe Doc, Doc)
|
||||||
|
tensorType refType arg
|
||||||
|
-- Deals with resource handles that are unlike tensors and
|
||||||
|
-- have their own type level representation. The magic "dtype"
|
||||||
|
-- name is the name of the attribute specifying the type of
|
||||||
|
-- the resource handle. It is not referenced by resource input
|
||||||
|
-- or output because it isn't expressible in TF operation
|
||||||
|
-- signatures.
|
||||||
|
| (arg ^. type' == DT_RESOURCE) =
|
||||||
|
(Nothing, strictText "ResourceHandle dtype")
|
||||||
|
| otherwise =
|
||||||
|
(Just refType,
|
||||||
|
"Tensor" <+> refType <+> maybe directType strictText indirectType)
|
||||||
where
|
where
|
||||||
indirectType = fmap fst (Map.lookup (arg ^. typeAttr) typeMap)
|
-- This is the case of a named parameter type which is
|
||||||
|
-- constrained as a OneOf.
|
||||||
|
indirectType = fst <$> (Map.lookup (arg ^. typeAttr) typeMap)
|
||||||
|
-- The nominal case when the type name is given directly.
|
||||||
directType = dtTypeToDoc (arg ^. type')
|
directType = dtTypeToDoc (arg ^. type')
|
||||||
outputs =
|
outputs =
|
||||||
case d ^. outputArg of
|
case d ^. outputArg of
|
||||||
[] -> "ControlNode"
|
[] -> "ControlNode"
|
||||||
[o] -> wrappedOutput o <+> "-- ^" <+> argComment o
|
[o] -> wrappedOutput o <+> "-- ^" <+> argComment o
|
||||||
os -> renderTupleResult os
|
os -> renderTupleResult os
|
||||||
wrappedOutput = wrapArg "Value"
|
wrappedOutput = snd . wrapArg "Value"
|
||||||
-- Tuple result case is rendered differently to give
|
-- Tuple result case is rendered differently to give
|
||||||
-- individual elements their own comments.
|
-- individual elements their own comments.
|
||||||
renderTupleResult os =
|
renderTupleResult os =
|
||||||
|
@ -324,7 +349,11 @@ argComment' argName argDesc =
|
||||||
bold :: Text.Text -> Doc
|
bold :: Text.Text -> Doc
|
||||||
bold n = strictText ("__" <> n <> "__")
|
bold n = strictText ("__" <> n <> "__")
|
||||||
|
|
||||||
opDefTypeMap :: OpDef -> Map.Map Text.Text (Text.Text, AttrDef)
|
type OpDefTypeMap = Map.Map Text.Text (Text.Text, AttrDef)
|
||||||
|
|
||||||
|
-- | Returns the map of type parameters from OpDef type name to
|
||||||
|
-- (haskell friendly type name, the type's attribute definition).
|
||||||
|
opDefTypeMap :: OpDef -> OpDefTypeMap
|
||||||
opDefTypeMap d =
|
opDefTypeMap d =
|
||||||
Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a]
|
Map.fromList [(n, (lowCase n, a)) | (n, a) <- attrList d, isType a]
|
||||||
|
|
||||||
|
@ -390,6 +419,8 @@ dtTypeToHaskell DT_STRING = "Data.ByteString.ByteString"
|
||||||
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
|
dtTypeToHaskell DT_UINT16 = "Data.Word.Word16"
|
||||||
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
dtTypeToHaskell DT_HALF = "Data.Word.Word16" -- TODO(gnezdo): make unique
|
||||||
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
|
dtTypeToHaskell DT_UINT8 = "Data.Word.Word8"
|
||||||
|
dtTypeToHaskell DT_RESOURCE =
|
||||||
|
error "ResourceHandle must be prevented from getting here."
|
||||||
dtTypeToHaskell x =
|
dtTypeToHaskell x =
|
||||||
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
|
Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show x
|
||||||
|
|
||||||
|
@ -419,6 +450,7 @@ replaceReservedName n
|
||||||
| n `Set.member` reservedKeywords = n <> "'"
|
| n `Set.member` reservedKeywords = n <> "'"
|
||||||
| otherwise = n
|
| otherwise = n
|
||||||
|
|
||||||
|
indentation :: Int
|
||||||
indentation = 4
|
indentation = 4
|
||||||
|
|
||||||
reservedKeywords :: Set.Set Text
|
reservedKeywords :: Set.Set Text
|
||||||
|
|
|
@ -77,11 +77,17 @@ instance ( OpResult a1
|
||||||
<*> toResult
|
<*> toResult
|
||||||
|
|
||||||
tensorResult :: TensorKind v -> Result (Tensor v a)
|
tensorResult :: TensorKind v -> Result (Tensor v a)
|
||||||
tensorResult v = do
|
tensorResult v = Tensor v <$> recordResult
|
||||||
|
|
||||||
|
recordResult :: Result Output
|
||||||
|
recordResult = do
|
||||||
o <- ask
|
o <- ask
|
||||||
ResultState i ns <- get
|
ResultState i ns <- get
|
||||||
put $! ResultState (i+1) ns
|
put $! ResultState (i+1) ns
|
||||||
return $! Tensor v $ output i o
|
return $! output i o
|
||||||
|
|
||||||
|
instance OpResult (ResourceHandle a) where
|
||||||
|
toResult = ResourceHandle <$> recordResult
|
||||||
|
|
||||||
instance OpResult (Tensor Value a) where
|
instance OpResult (Tensor Value a) where
|
||||||
toResult = tensorResult ValueKind
|
toResult = tensorResult ValueKind
|
||||||
|
@ -144,6 +150,9 @@ buildListOp counts o = buildOp' counts o []
|
||||||
instance BuildOp ControlNode where
|
instance BuildOp ControlNode where
|
||||||
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
|
buildOp' _ o ts = ControlNode $ Unrendered $ addReversedInputs o ts
|
||||||
|
|
||||||
|
instance BuildOp (ResourceHandle a) where
|
||||||
|
buildOp' = pureResult
|
||||||
|
|
||||||
instance BuildOp (Tensor Value a) where
|
instance BuildOp (Tensor Value a) where
|
||||||
buildOp' = pureResult
|
buildOp' = pureResult
|
||||||
|
|
||||||
|
@ -180,6 +189,9 @@ instance ( OpResult t1
|
||||||
instance OpResult a => BuildOp (Build a) where
|
instance OpResult a => BuildOp (Build a) where
|
||||||
buildOp' = buildResult
|
buildOp' = buildResult
|
||||||
|
|
||||||
|
instance BuildOp f => BuildOp (ResourceHandle a -> f) where
|
||||||
|
buildOp' rf o ts (ResourceHandle t) = buildOp' rf o (t : ts)
|
||||||
|
|
||||||
instance BuildOp f => BuildOp (Tensor v a -> f) where
|
instance BuildOp f => BuildOp (Tensor v a -> f) where
|
||||||
buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts)
|
buildOp' rf o ts t = buildOp' rf o (t ^. tensorOutput : ts)
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,7 @@ module TensorFlow.Output
|
||||||
, outputIndex
|
, outputIndex
|
||||||
, outputOp
|
, outputOp
|
||||||
, PendingNodeName(..)
|
, PendingNodeName(..)
|
||||||
|
, ResourceHandle(..)
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
|
@ -154,3 +155,8 @@ instance IsString Output where
|
||||||
_ -> Output 0 $ assigned s
|
_ -> Output 0 $ assigned s
|
||||||
where assigned n = Rendered $ def & name .~ Text.pack n
|
where assigned n = Rendered $ def & name .~ Text.pack n
|
||||||
|
|
||||||
|
|
||||||
|
-- | Opaque handle to a mutable resource in the graph. Typical such
|
||||||
|
-- resources are variables. The type parameter corresponds to the
|
||||||
|
-- dtype of the tensor held in the variable.
|
||||||
|
newtype ResourceHandle a = ResourceHandle Output
|
||||||
|
|
2
third_party/tensorflow
vendored
2
third_party/tensorflow
vendored
|
@ -1 +1 @@
|
||||||
Subproject commit bac7faa9a3eb5b60687a83336202cd3493de5385
|
Subproject commit e1c7e510a569cd5898f08015352bbdc8bef2ff7e
|
Loading…
Reference in a new issue