{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
module TensorFlow.Output
( ControlNode(..)
, Device(..)
, NodeName(..)
, OpDef(..)
, opName
, opType
, opAttr
, opInputs
, opControlInputs
, OpType(..)
, OutputIx(..)
, Output(..)
, output
, PendingNodeName(..)
) where
import qualified Data.Map.Strict as Map
import Data.String (IsString(..))
import Data.Text (Text)
import qualified Data.Text as Text
import Lens.Family2 (Lens')
import Lens.Family2.Unchecked (lens)
import Proto.Tensorflow.Core.Framework.AttrValue (AttrValue(..))
import Data.Default (def)
import TensorFlow.Types (Attribute, attrLens)
-- | Newtype wrapper around a 'NodeName'; 'unControlNode' recovers the
-- wrapped name.
newtype ControlNode = ControlNode
    { unControlNode :: NodeName
    }
-- | The type of an op, stored as 'Text'.
newtype OpType = OpType
    { unOpType :: Text
    }
    deriving (Eq, Ord, Show)
-- | Lets an 'OpType' be written as a string literal; the characters are
-- packed into 'Text' unchanged.
instance IsString OpType where
    fromString str = OpType (Text.pack str)
-- | A pairing of an output index with the name of a node. Both fields are
-- strict.
data Output = Output
    { outputIndex :: !OutputIx
      -- ^ The index component.
    , outputNodeName :: !NodeName
      -- ^ The node-name component.
    }
    deriving (Eq, Ord, Show)
-- | Build an 'Output' from its index and node name; equivalent to applying
-- the 'Output' constructor directly.
output :: OutputIx -> NodeName -> Output
output ix name = Output
    { outputIndex = ix
    , outputNodeName = name
    }
-- | An output index, represented as an 'Int'. The 'Num' and 'Enum'
-- instances are derived, so numeric literals can be used directly.
newtype OutputIx = OutputIx
    { unOutputIx :: Int
    }
    deriving (Eq, Ord, Num, Enum, Show)
-- | A device, identified by a 'Text' name. The derived 'IsString' instance
-- allows a 'Device' to be written as a string literal.
newtype Device = Device
    { deviceName :: Text
    }
    deriving (Eq, Ord, IsString)
-- | Shows a 'Device' exactly as its underlying 'Text' name would be shown,
-- without mentioning the 'Device' constructor.
instance Show Device where
    show = show . deviceName
-- | The definition of a single op: its (possibly pending) name, its type,
-- its attributes, and its regular and control inputs.
--
-- The name, type and attribute map are strict fields; the two input lists
-- carry no strictness annotation.
data OpDef = OpDef
    { _opName :: !PendingNodeName
      -- ^ Name of the op; accessed through 'opName'.
    , _opType :: !OpType
      -- ^ Type of the op; accessed through 'opType'.
    , _opAttrs :: !(Map.Map Text AttrValue)
      -- ^ Attribute values keyed by 'Text'; individual entries are
      -- accessed through 'opAttr'.
    , _opInputs :: [Output]
      -- ^ Inputs of the op; accessed through 'opInputs'.
    , _opControlInputs :: [NodeName]
      -- ^ Control inputs of the op; accessed through 'opControlInputs'.
    }
    deriving (Eq, Ord)
-- | The name slot of an 'OpDef': either a name that was given explicitly,
-- or a marker that none was given.
data PendingNodeName
    = ExplicitName !Text
      -- ^ A name supplied explicitly (strict).
    | ImplicitName
      -- ^ No explicit name was supplied.
    deriving (Eq, Ord, Show)
-- | A string literal denotes an 'ExplicitName' carrying that text.
instance IsString PendingNodeName where
    fromString str = ExplicitName (Text.pack str)
-- | The name of a node, stored as 'Text'.
newtype NodeName = NodeName
    { unNodeName :: Text
    }
    deriving (Eq, Ord, Show)
-- | Lens onto the '_opName' field of an 'OpDef'.
opName :: Lens' OpDef PendingNodeName
opName = lens _opName setName
  where
    setName opDef newName = opDef { _opName = newName }
-- | Lens onto the '_opType' field of an 'OpDef'.
opType :: Lens' OpDef OpType
opType = lens _opType setType
  where
    setType opDef newType = opDef { _opType = newType }
-- | Lens onto the attribute stored under the given key in an 'OpDef'.
--
-- Built from three lenses composed left to right: the whole attribute map,
-- the single 'AttrValue' under the key, and 'attrLens' to convert between
-- that 'AttrValue' and the caller's type. Reading a key that is absent
-- from the map yields 'def'; writing inserts (or replaces) the entry.
opAttr :: Attribute a => Text -> Lens' OpDef a
opAttr key = allAttrs . entryAt . attrLens
  where
    -- Explicit signatures keep these local lenses polymorphic in the
    -- functor, as the rank-2 'Lens'' alias requires.
    allAttrs :: Lens' OpDef (Map.Map Text AttrValue)
    allAttrs = lens _opAttrs (\opDef attrs -> opDef { _opAttrs = attrs })

    entryAt :: Lens' (Map.Map Text AttrValue) AttrValue
    entryAt = lens
        (Map.findWithDefault def key)
        (\attrs value -> Map.insert key value attrs)
-- | Lens onto the '_opInputs' field of an 'OpDef'.
opInputs :: Lens' OpDef [Output]
opInputs = lens _opInputs setInputs
  where
    setInputs opDef newInputs = opDef { _opInputs = newInputs }
-- | Lens onto the '_opControlInputs' field of an 'OpDef'.
opControlInputs :: Lens' OpDef [NodeName]
opControlInputs = lens _opControlInputs setControlInputs
  where
    setControlInputs opDef newInputs = opDef { _opControlInputs = newInputs }
-- | Lets an 'Output' be written as a string literal of the form @name@ or
-- @name:index@.
--
-- The string is split at its first @\':\'@. If everything after the colon
-- parses completely as an integer, that integer becomes the output index
-- and the text before the colon becomes the node name. In every other case
-- (no colon, or a suffix that is not entirely an integer) the whole string
-- is used as the node name with output index 0.
instance IsString Output where
    fromString s = case break (== ':') s of
        (n, ':' : ixStr)
            -- 'reads' (not the partial 'read') is required here: it returns
            -- the list of (value, remaining input) parses, so a non-numeric
            -- suffix simply fails the guard instead of throwing
            -- "Prelude.read: no parse". The empty remainder ensures the
            -- whole suffix was consumed.
            | [(ix, "" :: String)] <- reads ixStr
            -> Output (fromInteger ix) $ assigned n
        _ -> Output 0 $ assigned s
      where
        assigned = NodeName . Text.pack