tensorflow-haskell/tensorflow/src/TensorFlow/Output.hs

157 lines
5.1 KiB
Haskell

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE OverloadedStrings #-}
-- | Data types describing nodes of a TensorFlow graph and their outputs:
-- ops ('Op', 'OpDef'), node names ('NodeName', 'PendingNodeName'), op types
-- ('OpType'), devices ('Device'), outputs ('Output', 'OutputIx') and control
-- nodes ('ControlNode'), together with lenses for accessing their fields.
module TensorFlow.Output
( ControlNode(..)
, Device(..)
-- * Ops
, NodeName(..)
, Op(..)
, opUnrendered
, OpDef(..)
, opName
, opType
, opAttr
, opInputs
, opControlInputs
, OpType(..)
, OutputIx(..)
, Output(..)
, output
, outputIndex
, outputOp
, PendingNodeName(..)
) where
import qualified Data.Map.Strict as Map
import Data.ProtoLens.TextFormat (showMessage)
import Data.String (IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 (Lens', Traversal', (.~), (&), (^.))
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
import Proto.Tensorflow.Core.Framework.NodeDef (NodeDef(..), name)
import Data.Default (def)
import TensorFlow.Types (Attribute, attrLens)
import TensorFlow.Orphans ()
-- | A graph node that produces no outputs. Such nodes are useful for the
-- side effects they cause when they are run. This is a thin wrapper
-- around the underlying 'Op'.
newtype ControlNode = ControlNode
    { unControlNode :: Op
    }
-- | The kind of operation performed by a graph node. Corresponds to the
-- proto field @NodeDef.op@.
newtype OpType = OpType
    { unOpType :: Text
    }
    deriving (Eq, Ord, Show)

-- | Allows string literals to be used directly as an 'OpType'.
instance IsString OpType where
    fromString str = OpType (Text.pack str)
-- | One output of a TensorFlow node: the index of the output together with
-- the 'Op' it belongs to. Both fields are strict.
data Output = Output !OutputIx !Op
    deriving (Eq, Ord, Show)
-- | Construct an 'Output' from an output index and the op it belongs to.
output :: OutputIx -> Op -> Output
output ix op = Output ix op
-- | Lens focusing on the 'Op' component of an 'Output'; the output index is
-- left unchanged when setting.
outputOp :: Lens' Output Op
outputOp = lens getOp setOp
  where
    getOp (Output _ op) = op
    setOp (Output ix _) newOp = Output ix newOp
-- | Lens focusing on the 'OutputIx' component of an 'Output'; the op is left
-- unchanged when setting.
outputIndex :: Lens' Output OutputIx
outputIndex = lens getIx setIx
  where
    getIx (Output ix _) = ix
    setIx (Output _ op) newIx = Output newIx op
-- | Index of an output of a node, wrapping an 'Int'. The 'Num' and 'Enum'
-- instances are derived from the underlying 'Int', so numeric literals can
-- be used where an 'OutputIx' is expected.
newtype OutputIx = OutputIx
    { unOutputIx :: Int
    }
    deriving (Eq, Ord, Num, Enum, Show)
-- | A device to which a node can be assigned.
--
-- By naming convention, device names are built up from job and replica
-- names.
newtype Device = Device
    { deviceName :: Text
    }
    deriving (Eq, Ord, IsString)

-- | Shows only the device name (as a quoted string), without the
-- constructor.
instance Show Device where
    show = show . deviceName
-- | How a node of a TensorFlow graph is represented.
data Op
    = Rendered !NodeDef
      -- ^ Properties are fixed, including the device, name, and scope.
    | Unrendered !OpDef
      -- ^ Properties are not fixed, and may change depending on which
      -- context this op is rendered in.
    deriving (Eq, Ord)

-- | A rendered op is shown through the text format of its 'NodeDef'; an
-- unrendered op is shown only by its pending name.
instance Show Op where
    show op = case op of
        Rendered nodeDef -> "Rendered " ++ showMessage nodeDef
        Unrendered opDef -> "Unrendered " ++ show (opDef ^. opName)
-- | Traversal that visits the 'OpDef' of an 'Unrendered' op and leaves a
-- 'Rendered' op untouched.
--
-- The implementation follows the same pattern as @_Left@.
opUnrendered :: Traversal' Op OpDef
opUnrendered visit op = case op of
    Unrendered opDef -> Unrendered <$> visit opDef
    Rendered nodeDef -> pure (Rendered nodeDef)
-- | Definition of an op that has not been rendered yet. This corresponds
-- somewhat to the 'NodeDef' proto.
data OpDef = OpDef
    { _opName :: !PendingNodeName
      -- ^ Explicit or implicit name of the op.
    , _opType :: !OpType
      -- ^ The type of the op.
    , _opAttrs :: !(Map.Map Text AttrValue)
      -- ^ Attribute values, keyed by 'Text'.
    , _opInputs :: [Output]
      -- ^ Outputs used as inputs of this op.
    , _opControlInputs :: [NodeName]
      -- ^ Names of nodes used as control inputs of this op.
    }
    deriving (Eq, Ord)
-- | The name requested for an op that has not been rendered yet. It does
-- not contain the "scope" prefix.
data PendingNodeName
    = ExplicitName !Text
      -- ^ A name given explicitly.
    | ImplicitName
      -- ^ No name given; one will be assigned based on the opType plus a
      -- unique identifier.
    deriving (Eq, Ord, Show)
-- | The final name of a node in the graph, corresponding to the proto field
-- @NodeDef.name@. It includes the scope prefix (if any) and, for implicitly
-- named nodes, a unique identifier.
newtype NodeName = NodeName
    { unNodeName :: Text
    }
    deriving (Eq, Ord, Show)
-- | Lens onto the pending name of an 'OpDef'.
opName :: Lens' OpDef PendingNodeName
opName = lens _opName setName
  where
    setName opDef newName = opDef { _opName = newName }
-- | Lens onto the op type of an 'OpDef'.
opType :: Lens' OpDef OpType
opType = lens _opType setType
  where
    setType opDef newType = opDef { _opType = newType }
-- | Lens onto the attribute with the given name, viewed as a value of an
-- 'Attribute' type.
--
-- Reading a name that is absent from the attribute map yields the view of
-- the default ('def') attribute value; writing inserts (or replaces) the
-- entry under that name.
opAttr :: Attribute a => Text -> Lens' OpDef a
opAttr attrName =
    lens _opAttrs setAttrs . lens lookupAttr insertAttr . attrLens
  where
    -- Replace the whole attribute map of the op definition.
    setAttrs opDef attrs = opDef { _opAttrs = attrs }
    -- Fetch the raw attribute value, falling back to the default.
    lookupAttr attrs = Map.findWithDefault def attrName attrs
    -- Store the raw attribute value under the requested name.
    insertAttr attrs value = Map.insert attrName value attrs
-- | Lens onto the list of inputs of an 'OpDef'.
opInputs :: Lens' OpDef [Output]
opInputs = lens _opInputs setInputs
  where
    setInputs opDef newInputs = opDef { _opInputs = newInputs }
-- | Lens onto the list of control inputs of an 'OpDef'.
opControlInputs :: Lens' OpDef [NodeName]
opControlInputs = lens _opControlInputs setControlInputs
  where
    setControlInputs opDef newControls =
        opDef { _opControlInputs = newControls }
-- TODO(gnezdo): IsString instance is weird and we should move that
-- code into a Build function
--
-- Builds an 'Output' that refers to a 'Rendered' node by name. The string
-- is split at its first ':'.
--
--   * @"node:3"@: when everything after the first colon parses completely
--     as an integer, that integer is the output index and the text before
--     the colon is the node name.
--   * Anything else (no colon, or a suffix that is not entirely an
--     integer, e.g. @"node"@ or @"a:b:1"@): the whole string is taken as
--     the node name and the output index is 0.
instance IsString Output where
    fromString s = case break (== ':') s of
        (n, ':' : ixStr)
            -- 'reads' (not 'read') must be used here: the pattern expects
            -- the @[(value, rest)]@ list that 'reads' returns. With 'read'
            -- the suffix would be parsed as a list-of-tuples literal and
            -- throw "Prelude.read: no parse" on inputs such as "node:0".
            | [(ix, "")] <- reads ixStr -> Output (fromInteger ix) $ assigned n
        _ -> Output 0 $ assigned s
      where
        -- A rendered node whose NodeDef carries only the given name.
        assigned n = Rendered $ def & name .~ Text.pack n