mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-30 06:49:44 +01:00
fb629d1207
Instead of calling TF_ExtendGraph, we call TF_GraphImportGraphDef and pass an input map for all existing nodes in the graph.
127 lines
4.1 KiB
Haskell
127 lines
4.1 KiB
Haskell
-- Copyright 2016 TensorFlow authors.
|
|
--
|
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
|
-- you may not use this file except in compliance with the License.
|
|
-- You may obtain a copy of the License at
|
|
--
|
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
|
--
|
|
-- Unless required by applicable law or agreed to in writing, software
|
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
-- See the License for the specific language governing permissions and
|
|
-- limitations under the License.
|
|
|
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE Rank2Types #-}
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
|
|
module TensorFlow.Output
|
|
( ControlNode(..)
|
|
, Device(..)
|
|
-- * Ops
|
|
, NodeName(..)
|
|
, OpDef(..)
|
|
, opName
|
|
, opType
|
|
, opAttr
|
|
, opInputs
|
|
, opControlInputs
|
|
, OpType(..)
|
|
, OutputIx(..)
|
|
, Output(..)
|
|
, output
|
|
, PendingNodeName(..)
|
|
) where
|
|
|
|
import Data.ProtoLens.Message(defMessage)
|
|
import qualified Data.Map.Strict as Map
|
|
import Data.String (IsString(..))
|
|
import Data.Text (Text)
|
|
import qualified Data.Text as Text
|
|
import Lens.Family2 (Lens')
|
|
import Lens.Family2.Unchecked (lens)
|
|
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue)
|
|
import TensorFlow.Types (Attribute, attrLens)
|
|
|
|
-- | A graph node that has no outputs. Such nodes are valuable for the
-- side effects they cause when they are run.
newtype ControlNode = ControlNode
    { unControlNode :: NodeName  -- ^ Name of the wrapped node.
    }
|
|
|
|
-- | The kind of op carried by a node in the graph. This corresponds to the
-- proto field NodeDef.op.
newtype OpType = OpType
    { unOpType :: Text  -- ^ The op type as raw text.
    }
    deriving (Eq, Ord, Show)
|
|
|
|
-- | Lets string literals stand in for 'OpType' values by packing the
-- characters into 'Text'.
instance IsString OpType where
    fromString str = OpType (Text.pack str)
|
|
|
|
-- | An output of a TensorFlow node.
--
-- Both fields are strict. Field order matters for the derived 'Ord' and
-- 'Show' instances: the index comes first, then the node name.
data Output = Output
    { outputIndex    :: !OutputIx  -- ^ The 'OutputIx' stored for this output.
    , outputNodeName :: !NodeName  -- ^ The 'NodeName' stored for this output.
    }
    deriving (Eq, Ord, Show)
|
|
|
|
-- | Builds an 'Output' from an output index and a node name.
-- This is a plain function alias for the 'Output' constructor.
output :: OutputIx -> NodeName -> Output
output ix name = Output { outputIndex = ix, outputNodeName = name }
|
|
|
|
-- | An 'Int' wrapped in a newtype; used as the 'outputIndex' field of
-- 'Output'. The numeric instances are derived from the underlying 'Int',
-- so integer literals can be used directly.
newtype OutputIx = OutputIx
    { unOutputIx :: Int  -- ^ The underlying integer.
    }
    deriving (Eq, Ord, Num, Enum, Show)
|
|
|
|
-- | A device that a node can be assigned to.
--
-- By naming convention, device names are constructed from job and
-- replica names.
newtype Device = Device
    { deviceName :: Text  -- ^ The device name as raw text.
    }
    deriving (Eq, Ord, IsString)
|
|
|
|
-- | Renders a 'Device' exactly as its underlying name would be shown,
-- without the constructor wrapper.
instance Show Device where
    show = show . deviceName
|
|
|
|
-- | Op definition. This corresponds somewhat to the 'NodeDef' proto.
--
-- The name, type and attribute map are strict fields; the two input lists
-- are left lazy.
data OpDef = OpDef
    { _opName          :: !PendingNodeName
      -- ^ Requested (not yet rendered) name; see 'opName'.
    , _opType          :: !OpType
      -- ^ Type of the op; see 'opType'.
    , _opAttrs         :: !(Map.Map Text AttrValue)
      -- ^ Attribute values keyed by the name passed to 'opAttr'.
    , _opInputs        :: [Output]
      -- ^ Outputs listed as inputs of this op; see 'opInputs'.
    , _opControlInputs :: [NodeName]
      -- ^ Node names listed as control inputs; see 'opControlInputs'.
    }
    deriving (Eq, Ord)
|
|
|
|
-- | The name specified for an unrendered Op. An Op with an 'ImplicitName'
-- will have its name assigned based on the opType plus a unique
-- identifier. The "scope" prefix is not part of this name.
data PendingNodeName
    = ExplicitName !Text  -- ^ A caller-supplied name (strict).
    | ImplicitName        -- ^ No name supplied; one is assigned later.
    deriving (Eq, Ord, Show)
|
|
|
|
-- | A string literal used as a 'PendingNodeName' is always an
-- 'ExplicitName' carrying that literal's text.
instance IsString PendingNodeName where
    fromString name = ExplicitName (fromString name)
|
|
|
|
-- | The name of a node in the graph, corresponding to the proto field
-- NodeDef.name. It includes the scope prefix (if any) and, for implicitly
-- named nodes, a unique identifier.
newtype NodeName = NodeName
    { unNodeName :: Text  -- ^ The node name as raw text.
    }
    deriving (Eq, Ord, Show)
|
|
|
|
-- | Lens onto the pending name ('_opName') of an 'OpDef'.
opName :: Lens' OpDef PendingNodeName
opName = lens _opName setName
  where
    setName def newName = def { _opName = newName }
|
|
|
|
-- | Lens onto the op type ('_opType') of an 'OpDef'.
opType :: Lens' OpDef OpType
opType = lens _opType setType
  where
    setType def newType = def { _opType = newType }
|
|
|
|
-- | Lens onto a single typed attribute of an 'OpDef', selected by name.
--
-- Reading an attribute that is absent from the attribute map starts from
-- 'defMessage'; writing inserts (or overwrites) the entry under that name.
-- Conversion between the stored 'AttrValue' and @a@ is delegated to
-- 'attrLens'.
opAttr :: Attribute a => Text -> Lens' OpDef a
opAttr n = attrMap . entry . attrLens
  where
    -- The whole attribute map of the op.
    attrMap :: Lens' OpDef (Map.Map Text AttrValue)
    attrMap = lens _opAttrs (\def attrs -> def { _opAttrs = attrs })

    -- The map entry stored under @n@, defaulting to 'defMessage' on read.
    entry :: Lens' (Map.Map Text AttrValue) AttrValue
    entry = lens (Map.findWithDefault defMessage n)
                 (\attrs value -> Map.insert n value attrs)
|
|
|
|
-- | Lens onto the list of input 'Output's ('_opInputs') of an 'OpDef'.
opInputs :: Lens' OpDef [Output]
opInputs = lens _opInputs setInputs
  where
    setInputs def newInputs = def { _opInputs = newInputs }
|
|
|
|
-- | Lens onto the list of control-input node names ('_opControlInputs')
-- of an 'OpDef'.
opControlInputs :: Lens' OpDef [NodeName]
opControlInputs = lens _opControlInputs setControlInputs
  where
    setControlInputs def newControls = def { _opControlInputs = newControls }
|
|
|
|
-- TODO(gnezdo): IsString instance is weird and we should move that
-- code into a Build function

-- | Parses a string literal of the form @"name"@ or @"name:index"@.
--
-- The string is split at the FIRST @\':\'@:
--
-- * With no colon, the whole string is the node name and the output index
--   is 0.
-- * With a colon, the text before it is the node name and the text after
--   it must be an integer, which becomes the output index.
--
-- 'fromString' cannot report failure in its type, so a malformed index
-- (e.g. @"foo:bar"@ or @"foo:"@) raises an 'error' that names the offending
-- literal, rather than the uninformative @Prelude.read: no parse@.
instance IsString Output where
    fromString s = case break (== ':') s of
        (n, ':' : ixStr) -> Output (parseIx ixStr) (assigned n)
        _                -> Output 0 (assigned s)
      where
        assigned = NodeName . Text.pack

        -- Total replacement for 'read': accept exactly one full parse,
        -- allowing trailing whitespace via 'lex' just like 'read' does.
        parseIx :: String -> OutputIx
        parseIx ixStr =
            case [ix | (ix, rest) <- reads ixStr, ("", "") <- lex rest] of
                [ix] -> fromInteger ix
                _    -> error $
                    "TensorFlow.Output.fromString: invalid output index "
                    ++ show ixStr ++ " in " ++ show s
|