
307 lines
9.6 KiB

-- Copyright 2016 TensorFlow authors.
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
-- http://www.apache.org/licenses/LICENSE-2.0
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module TensorFlow.BuildOp
( BuildResult(..)
, buildOp
, PureResult(..)
, pureOp
, eqLengthGuard
, BuildInputs(..)
, OpParams
import Control.Monad (liftM2, replicateM)
import Control.Monad.Reader (ReaderT, runReaderT, ask)
import Control.Monad.State.Strict (State, evalState, get, put)
import Data.Int (Int64)
import TensorFlow.Build
import TensorFlow.Output
import TensorFlow.Tensor
import TensorFlow.Types
-- | Cursor threaded through result decoding: the index of the next output
-- to hand out, followed by the remaining element counts that size
-- list-valued results.
data ResultState = ResultState !OutputIx [Int64]
    deriving Show

-- | Decoding monad for 'buildOp': reads the new node's 'NodeName' from the
-- environment and advances a 'ResultState'.
type Result = ReaderT NodeName (State ResultState)
-- | Types that 'buildOp' can decode from the node it has just added.
-- Each instance claims the outputs (and list counts) it needs from the
-- 'Result' state.
class BuildResult a where
    buildResult :: Result a
-- Tuple results decode their components left to right via '<*>', so earlier
-- components consume outputs (and list counts) before later ones.
instance (BuildResult a1, BuildResult a2) => BuildResult (a1, a2) where
    buildResult = (,) <$> buildResult <*> buildResult

instance (BuildResult a1, BuildResult a2, BuildResult a3)
         => BuildResult (a1, a2, a3) where
    buildResult = (,,) <$> buildResult <*> buildResult <*> buildResult

instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4)
         => BuildResult (a1, a2, a3, a4) where
    buildResult = (,,,) <$> buildResult <*> buildResult <*> buildResult <*> buildResult

instance (BuildResult a1, BuildResult a2, BuildResult a3, BuildResult a4, BuildResult a5)
         => BuildResult (a1, a2, a3, a4, a5) where
    buildResult = (,,,,) <$> buildResult
                         <*> buildResult
                         <*> buildResult
                         <*> buildResult
                         <*> buildResult

instance ( BuildResult a1
         , BuildResult a2
         , BuildResult a3
         , BuildResult a4
         , BuildResult a5
         , BuildResult a6
         )
         => BuildResult (a1, a2, a3, a4, a5, a6) where
    buildResult = (,,,,,)
                  <$> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult

instance ( BuildResult a1
         , BuildResult a2
         , BuildResult a3
         , BuildResult a4
         , BuildResult a5
         , BuildResult a6
         , BuildResult a7
         )
         => BuildResult (a1, a2, a3, a4, a5, a6, a7) where
    buildResult = (,,,,,,)
                  <$> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult

instance ( BuildResult a1
         , BuildResult a2
         , BuildResult a3
         , BuildResult a4
         , BuildResult a5
         , BuildResult a6
         , BuildResult a7
         , BuildResult a8
         )
         => BuildResult (a1, a2, a3, a4, a5, a6, a7, a8) where
    buildResult = (,,,,,,,)
                  <$> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
                  <*> buildResult
-- | Hands out the next output of the current node and bumps the output
-- index by one; the pending list counts are left unchanged.
recordResult :: Result Output
recordResult = do
    ResultState ix counts <- get
    node <- ask
    put $! ResultState (ix + 1) counts
    return $! output ix node
-- | A single tensor claims exactly one output of the node.
instance (TensorKind v, Rendered (Tensor v)) => BuildResult (Tensor v a) where
    buildResult = fmap (Tensor . pure) recordResult
-- | A control node wraps the node name itself and consumes no outputs.
instance BuildResult ControlNode where
    buildResult = fmap ControlNode ask
-- | A heterogeneous tensor list decodes one tensor per element type, in
-- list order. Each tensor claims one output via the 'Tensor' instance.
instance (TensorKind v, Rendered (Tensor v), TensorTypes as)
         => BuildResult (TensorList v as) where
    buildResult = loop (tensorTypes :: TensorTypeList as)
      where
        -- Walk the type-level list of element types (witnessed by
        -- 'tensorTypes'), decoding one tensor for each entry.
        loop :: TensorTypeList bs -> Result (TensorList v bs)
        loop Nil = return Nil
        loop (TensorTypeProxy :/ rest) = do
            headTensor <- buildResult
            tailTensors <- loop rest
            return (headTensor :/ tailTensors)
-- | A list result pops the next pending count and decodes that many
-- elements. Running out of counts means 'buildOp' was given too few sizes.
instance BuildResult a => BuildResult [a] where
    buildResult = get >>= \(ResultState ix counts) -> case counts of
        n : remaining -> do
            put $! ResultState ix remaining
            replicateM (fromIntegral n) buildResult
        [] -> error $ "Ran out of counts in buildResult. " ++
                      "Likely misuse of buildOp."
-- | Adds the op via 'addNewOp' and decodes the resulting node as @a@.
--
-- @sizes@ supplies, in order, the element count for each list-valued part
-- of the result; output indices start at 0.
buildOp :: BuildResult a => [Int64] -> OpDef -> Build a
buildOp sizes o = decode <$> addNewOp o
  where
    decode node = evalState (runReaderT buildResult node) (ResultState 0 sizes)
-- | Checks that, within every @number_attr@ entry, all paired 'Int' values
-- are identical. Entries with no pairs are trivially consistent.
--
-- Returns 'True' when every entry is consistent. Otherwise it throws an
-- 'error' naming the offending attribute and showing all of its
-- @(name, value)@ pairs.
eqLengthGuard :: [(String, [(String, Int)])] -> Bool
eqLengthGuard = all eachOk
  where
    eachOk :: (String, [(String, Int)]) -> Bool
    eachOk (_, []) = True
    -- Every value must equal the first one: (== 1) . length . nub in disguise.
    eachOk (numberAttrName, pairs@((_, x) : zs)) =
        all ((== x) . snd) zs ||
        error ("number_attr " ++ numberAttrName ++
               " contains tensors with different length " ++ show pairs)
-- | Types that 'pureOp' can produce from a deferred op definition.
-- Instances read the @Build OpDef@ action from the environment and claim
-- output indices (and list counts) from the 'ResultState'.
class PureResult a where
    pureResult :: ReaderT (Build OpDef) (State ResultState) a
-- | A pure tensor claims one output index immediately. The op itself is
-- only resolved through 'getOrAddOp' when the tensor's 'Build' action runs.
instance PureResult (Tensor Build a) where
    pureResult = do
        ResultState ix counts <- get
        put $! ResultState (ix + 1) counts
        mkOpDef <- ask
        -- TODO: unify with BuildResult (Tensor v)
        return $ Tensor (output ix <$> (mkOpDef >>= getOrAddOp))
-- Tuple results decode their components left to right via '<*>', so earlier
-- components claim output indices (and list counts) before later ones.
instance (PureResult a1, PureResult a2) => PureResult (a1, a2) where
    pureResult = (,) <$> pureResult <*> pureResult

instance (PureResult a1, PureResult a2, PureResult a3)
         => PureResult (a1, a2, a3) where
    pureResult = (,,) <$> pureResult <*> pureResult <*> pureResult

instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4)
         => PureResult (a1, a2, a3, a4) where
    pureResult = (,,,) <$> pureResult <*> pureResult <*> pureResult <*> pureResult

instance (PureResult a1, PureResult a2, PureResult a3, PureResult a4, PureResult a5)
         => PureResult (a1, a2, a3, a4, a5) where
    pureResult = (,,,,) <$> pureResult
                        <*> pureResult
                        <*> pureResult
                        <*> pureResult
                        <*> pureResult

instance ( PureResult a1
         , PureResult a2
         , PureResult a3
         , PureResult a4
         , PureResult a5
         , PureResult a6
         )
         => PureResult (a1, a2, a3, a4, a5, a6) where
    pureResult = (,,,,,)
                 <$> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult

instance ( PureResult a1
         , PureResult a2
         , PureResult a3
         , PureResult a4
         , PureResult a5
         , PureResult a6
         , PureResult a7
         )
         => PureResult (a1, a2, a3, a4, a5, a6, a7) where
    pureResult = (,,,,,,)
                 <$> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult

instance ( PureResult a1
         , PureResult a2
         , PureResult a3
         , PureResult a4
         , PureResult a5
         , PureResult a6
         , PureResult a7
         , PureResult a8
         )
         => PureResult (a1, a2, a3, a4, a5, a6, a7, a8) where
    pureResult = (,,,,,,,)
                 <$> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
                 <*> pureResult
-- | A list result takes the next pending count and decodes that many
-- elements. Running out of counts means 'pureOp' was given too few sizes.
instance PureResult a => PureResult [a] where
    pureResult = do
        ResultState ix counts <- get
        case counts of
            count : remaining -> do
                put $! ResultState ix remaining
                replicateM (fromIntegral count) pureResult
            [] -> error $ "Ran out of counts in pureResult. " ++
                          "Likely misuse of pureOp with output lists."
-- | A heterogeneous tensor list decodes one pure tensor per element type,
-- in list order.
instance TensorTypes as => PureResult (TensorList Build as) where
    pureResult = loop (tensorTypes :: TensorTypeList as)
      where
        -- Walk the type-level list of element types (witnessed by
        -- 'tensorTypes'), decoding one tensor for each entry.
        loop :: TensorTypeList bs
             -> ReaderT (Build OpDef) (State ResultState) (TensorList Build bs)
        loop Nil = return Nil
        loop (TensorTypeProxy :/ rest) = do
            headTensor <- pureResult
            tailTensors <- loop rest
            return (headTensor :/ tailTensors)
-- | Decodes a pure result from a deferred op definition.
--
-- @sizes@ supplies, in order, the element count for each list-valued part
-- of the result; output indices start at 0.
pureOp :: PureResult a => [Int64] -> Build OpDef -> a
pureOp sizes o = evalState (runReaderT pureResult o) initial
  where
    initial = ResultState 0 sizes
-- | Types that can be passed as op arguments. Each value renders to the
-- list of 'Output's it contributes as inputs.
class BuildInputs a where
    buildInputs :: a -> Build [Output]
-- | A list contributes the inputs of its elements, concatenated in order.
instance BuildInputs a => BuildInputs [a] where
    buildInputs xs = concat <$> traverse buildInputs xs
-- | A tensor contributes exactly one input: the output its action renders.
instance BuildInputs (Tensor v a) where
    buildInputs (Tensor t) = (: []) <$> toBuild t
-- | A heterogeneous tensor list contributes one input per element, in order.
instance BuildInputs (ListOf (Tensor v) as) where
    buildInputs Nil = return []
    buildInputs (t :/ ts) = do
        here <- buildInputs t
        rest <- buildInputs ts
        return (here ++ rest)
-- | A transformation of an 'OpDef' that supplies parameters for building an
-- op, such as the node name or optional attributes.
-- TODO: be more type safe.
type OpParams = OpDef -> OpDef