1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-11-30 06:49:44 +01:00
tensorflow-haskell/tensorflow/src/TensorFlow/Output.hs
Bart Schuurmans fb629d1207 Migrate from TF_DeprecatedSession to TF_Session
Instead of calling TF_ExtendGraph, we call TF_GraphImportGraphDef and
pass an input map for all existing nodes in the graph.
2023-02-04 15:18:03 -08:00

127 lines
4.1 KiB
Haskell

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Basic types for describing graph nodes: node names, op types, devices,
-- node outputs, and the 'OpDef' record together with lenses onto its fields.
module TensorFlow.Output
( ControlNode(..)
, Device(..)
-- * Ops
, NodeName(..)
, OpDef(..)
, opName
, opType
, opAttr
, opInputs
, opControlInputs
, OpType(..)
, OutputIx(..)
, Output(..)
, output
, PendingNodeName(..)
) where
import Data.ProtoLens.Message(defMessage)
import qualified Data.Map.Strict as Map
import Data.String (IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 (Lens')
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue)
import TensorFlow.Types (Attribute, attrLens)
-- | A graph node that produces no outputs. Running such a node is
-- useful for the side effects it causes.
newtype ControlNode = ControlNode
    { unControlNode :: NodeName
      -- ^ Name of the wrapped node.
    }
-- | The op type of a node in the graph, i.e. the value of the proto
-- field @NodeDef.op@.
newtype OpType = OpType
    { unOpType :: Text
      -- ^ The op type as text.
    }
    deriving (Eq, Ord, Show)
-- | Builds an 'OpType' from a 'String' by packing it into 'Text'.
instance IsString OpType where
    fromString str = OpType (Text.pack str)
-- | One output of a TensorFlow node, identified by an output index and
-- the name of the node.
data Output = Output
    { outputIndex    :: !OutputIx
      -- ^ Index of this output.
    , outputNodeName :: !NodeName
      -- ^ Name of the node this output belongs to.
    }
    deriving (Eq, Ord, Show)
-- | Build an 'Output' from an output index and a node name.
output :: OutputIx -> NodeName -> Output
output ix name = Output { outputIndex = ix, outputNodeName = name }
-- | Index of an output, wrapping an 'Int'. The numeric instances are
-- derived from the underlying 'Int'.
newtype OutputIx = OutputIx
    { unOutputIx :: Int
      -- ^ The raw index.
    }
    deriving (Eq, Ord, Num, Enum, Show)
-- | A device to which a node can be assigned.
-- By naming convention, device names are built from job and replica
-- names.
newtype Device = Device
    { deviceName :: Text
      -- ^ The device name as text.
    }
    deriving (Eq, Ord, IsString)
-- | Shows a 'Device' exactly as its underlying name is shown, i.e.
-- without the constructor.
instance Show Device where
    show = show . deviceName
-- | Definition of an op. Roughly the counterpart of the 'NodeDef' proto.
data OpDef = OpDef
    { _opName          :: !PendingNodeName
      -- ^ Requested name of the op (explicit or implicit).
    , _opType          :: !OpType
      -- ^ Type of the op.
    , _opAttrs         :: !(Map.Map Text AttrValue)
      -- ^ Attribute values keyed by attribute name.
    , _opInputs        :: [Output]
      -- ^ Outputs of other nodes that are used as inputs.
    , _opControlInputs :: [NodeName]
      -- ^ Names of nodes used as control inputs.
    }
    deriving (Eq, Ord)
-- | The name requested for an Op that has not been rendered yet. An Op
-- with an 'ImplicitName' gets a name assigned from its opType plus a
-- unique identifier. The "scope" prefix is not part of this name.
data PendingNodeName
    = ExplicitName !Text
    | ImplicitName
    deriving (Eq, Ord, Show)
-- | A string always yields an 'ExplicitName' carrying that string.
instance IsString PendingNodeName where
    fromString str = ExplicitName (fromString str)
-- | Name of a node in the graph, i.e. the value of the proto field
-- @NodeDef.name@. It includes the scope prefix (if any) and, for
-- implicitly named nodes, a unique identifier.
newtype NodeName = NodeName
    { unNodeName :: Text
      -- ^ The node name as text.
    }
    deriving (Eq, Ord, Show)
-- | Lens onto the (pending) name of an 'OpDef'.
opName :: Lens' OpDef PendingNodeName
opName = lens _opName setName
  where
    setName def newName = def { _opName = newName }
-- | Lens onto the op type of an 'OpDef'.
opType :: Lens' OpDef OpType
opType = lens _opType setType
  where
    setType def newType = def { _opType = newType }
-- | Lens onto the attribute stored under the given name in an 'OpDef'.
--
-- Reading a name that is absent from the attribute map goes through
-- 'defMessage'; writing inserts (or replaces) the entry for that name.
-- The final 'attrLens' step converts between 'AttrValue' and the
-- requested attribute type.
opAttr :: Attribute a => Text -> Lens' OpDef a
opAttr key = attrsOf . entryAt . attrLens
  where
    -- Focus on the whole attribute map of the op.
    attrsOf :: Lens' OpDef (Map.Map Text AttrValue)
    attrsOf = lens _opAttrs (\def attrs -> def { _opAttrs = attrs })

    -- Focus on the single entry for 'key', defaulting when it is missing.
    entryAt :: Lens' (Map.Map Text AttrValue) AttrValue
    entryAt = lens (Map.findWithDefault defMessage key)
                   (\attrs value -> Map.insert key value attrs)
-- | Lens onto the list of inputs of an 'OpDef'.
opInputs :: Lens' OpDef [Output]
opInputs = lens _opInputs setInputs
  where
    setInputs def newInputs = def { _opInputs = newInputs }
-- | Lens onto the list of control inputs of an 'OpDef'.
opControlInputs :: Lens' OpDef [NodeName]
opControlInputs = lens _opControlInputs setControlInputs
  where
    setControlInputs def newControls = def { _opControlInputs = newControls }
-- TODO(gnezdo): IsString instance is weird and we should move that
-- code into a Build function
--
-- | Parses an 'Output' from a string of the form @name@ or @name:index@.
--
-- The string is split at its first @':'@. Without a colon the whole
-- string is the node name and the output index is 0. With a colon, the
-- text after it must be an integer literal, parsed with the same rules
-- as 'read'. A malformed index raises an 'error' that names the
-- offending index and the full input, instead of the bare
-- @Prelude.read: no parse@ that a direct 'read' would produce.
instance IsString Output where
    fromString s = case break (== ':') s of
        (n, ':' : ixStr) -> Output (parseIndex ixStr) (assigned n)
        _                -> Output 0 (assigned s)
      where
        assigned = NodeName . Text.pack

        -- Accepts exactly the inputs 'read' accepts: a single full parse,
        -- with only whitespace allowed after the number.
        parseIndex :: String -> OutputIx
        parseIndex ixStr =
            case [ix | (ix, rest) <- reads ixStr, ("", "") <- lex rest] of
                [ix] -> fromInteger ix
                _    -> error $
                    "TensorFlow.Output.fromString: invalid output index "
                    ++ show ixStr ++ " in " ++ show s