mirror of
https://github.com/tensorflow/haskell.git
synced 2024-11-17 16:39:43 +01:00
101 lines
3.5 KiB
Haskell
101 lines
3.5 KiB
Haskell
|
-- Copyright 2016 TensorFlow authors.
|
||
|
--
|
||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
-- you may not use this file except in compliance with the License.
|
||
|
-- You may obtain a copy of the License at
|
||
|
--
|
||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||
|
--
|
||
|
-- Unless required by applicable law or agreed to in writing, software
|
||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
-- See the License for the specific language governing permissions and
|
||
|
-- limitations under the License.
|
||
|
|
||
|
-- | Generates the wrappers for Ops shipped with tensorflow_c.
|
||
|
module Main where
|
||
|
|
||
|
import Distribution.Simple.BuildPaths (autogenModulesDir)
|
||
|
import Distribution.Simple.LocalBuildInfo (LocalBuildInfo)
|
||
|
import Distribution.Simple
|
||
|
( defaultMainWithHooks
|
||
|
, simpleUserHooks
|
||
|
, UserHooks(..)
|
||
|
)
|
||
|
import Data.List (intercalate)
|
||
|
import Data.ProtoLens (decodeMessage)
|
||
|
import System.Directory (createDirectoryIfMissing)
|
||
|
import System.Exit (exitFailure)
|
||
|
import System.FilePath ((</>))
|
||
|
import System.IO (hPutStrLn, stderr)
|
||
|
import TensorFlow.Internal.FFI (getAllOpList)
|
||
|
import TensorFlow.OpGen (docOpList, OpGenFlags(..))
|
||
|
import Text.PrettyPrint.Mainland (prettyLazyText)
|
||
|
import qualified Data.Text.Lazy.IO as Text
|
||
|
|
||
|
-- | Cabal @Setup@ entry point: runs Cabal's standard driver
-- ('defaultMainWithHooks') with 'generatingOpsWrappers' installed, so the
-- TensorFlow op wrappers are generated before the package is built.
main :: IO ()
main = defaultMainWithHooks generatingOpsWrappers
|
||
|
|
||
|
-- TODO: Generalize for user libraries by replacing getAllOpList with
-- a wrapper around TF_LoadLibrary. The complicated part is the interplay
-- between bazel and the Haskell build system.

-- | Cabal hooks that regenerate the TensorFlow op wrappers before running
-- the build, haddock, or repl step, and then delegate to the corresponding
-- action from 'simpleUserHooks'.
generatingOpsWrappers :: UserHooks
generatingOpsWrappers = hooks
    { buildHook = \p l h f -> generateSources l >> buildHook hooks p l h f
    , haddockHook = \p l h f -> generateSources l >> haddockHook hooks p l h f
    , replHook = \p l h f args -> generateSources l
                                      >> replHook hooks p l h f args
    }
  where
    -- The stock hooks being extended. Each override above runs code
    -- generation first and then defers to the original action.
    hooks :: UserHooks
    hooks = simpleUserHooks

    -- Op-generator settings: output goes to @dir/Core.hs@, generated names
    -- use the "TensorFlow.GenOps" prefix, and the ops in 'excludedOps' are
    -- passed as a single comma-separated string.
    flagsBuilder :: FilePath -> OpGenFlags
    flagsBuilder dir = OpGenFlags
        { outputFile = dir </> "Core.hs"
        , prefix = "TensorFlow.GenOps"
        , excludeList = intercalate "," excludedOps
        }

    -- Fetch the serialized op list through the FFI ('getAllOpList'),
    -- decode it, and write the pretty-printed wrappers (80-column layout)
    -- under Cabal's autogen modules directory. The directory is created if
    -- needed. On a decode failure the error is reported on stderr and the
    -- process exits with a failure status, which aborts the Cabal step.
    generateSources :: LocalBuildInfo -> IO ()
    generateSources l = do
        let dir = autogenModulesDir l </> "TensorFlow" </> "GenOps"
            flags = flagsBuilder dir
        createDirectoryIfMissing True dir
        pb <- getAllOpList
        case decodeMessage pb of
            Left e -> do
                -- Include context: the bare proto-lens error alone does not
                -- say which step of the build failed.
                hPutStrLn stderr $
                    "Failed to decode the op list returned by getAllOpList: "
                    ++ e
                exitFailure
            Right x -> Text.writeFile (outputFile flags)
                                      (prettyLazyText 80 $ docOpList flags x)

    -- Ops for which no wrapper is generated, grouped by the reason the
    -- generator cannot handle them yet.
    excludedOps :: [String]
    excludedOps =
        -- A few data flow ops take a list of heterogeneous
        -- parameters which we don't support in general form.
        [ "HashTable"
        , "MutableDenseHashTable"
        , "MutableHashTable"
        , "MutableHashTableOfTensors"
        , "QueueDequeue"
        , "QueueDequeueMany"
        , "QueueDequeueUpTo"
        , "Stack"
        , "TensorArray"
        -- These should be possible to support by adding a bunch of
        -- overloads with a variable number of tuple arguments.
        , "Assert"
        , "BarrierTakeMany"
        , "Print"
        , "QueueEnqueue"
        , "QueueEnqueueMany"
        -- These have type ambiguities because one of the type arguments
        -- doesn't appear in the signature.
        , "ConditionalAccumulator"
        , "SparseConditionalAccumulator"
        -- Need list of types support.
        , "DecodeCSV"
        , "ParseExample"
        , "ParseSingleSequenceExample"
        , "Save"
        , "SaveSlices"
        , "SymbolicGradient"
        , "_ArrayToList"
        , "_ListToArray"
        -- Easy: support larger result tuples.
        , "Skipgram"
        ]
|