1
0
Fork 0
mirror of https://github.com/tensorflow/haskell.git synced 2024-12-28 12:39:46 +01:00
tensorflow-haskell/tensorflow-core-ops/Setup.hs
2023-01-20 09:50:24 +01:00

168 lines
5.2 KiB
Haskell

-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
{-# LANGUAGE CPP #-}
-- | Generates the wrappers for Ops shipped with tensorflow.
module Main where
import Data.List (intercalate)
import System.Exit (exitFailure)
import System.IO
    ( IOMode(WriteMode)
    , hPutStrLn
    , hSetEncoding
    , stderr
    , utf8
    , withFile
    )

import qualified Data.Text.Lazy.IO as Text
import Data.ProtoLens (decodeMessage)
import Distribution.PackageDescription
    ( PackageDescription(..)
    , libBuildInfo
    , hsSourceDirs
    )
import Distribution.Simple
    ( defaultMainWithHooks
    , simpleUserHooks
    , UserHooks(..)
    )
import qualified Distribution.Simple.BuildPaths as BuildPaths
import Distribution.Simple.LocalBuildInfo (LocalBuildInfo)
import System.Directory (createDirectoryIfMissing)
import System.FilePath ((</>))
import Text.PrettyPrint.Mainland (prettyLazyText)

import TensorFlow.Internal.FFI (getAllOpList)
import TensorFlow.OpGen (docOpList, OpGenFlags(..))
-- | Run Cabal's default driver with hooks that generate the op wrappers first.
main :: IO ()
main = defaultMainWithHooks generatingOpsWrappers
-- TODO: Generalize for user libraries by replacing getAllOpList with
-- a wrapper around TF_LoadLibrary. The complicated part is interplay
-- between bazel and Haskell build system.

-- | Cabal hooks that regenerate the TensorFlow op wrappers and then delegate
-- to the default ('simpleUserHooks') @build@, @haddock@ and @repl@ hooks.
--
-- The wrappers are written to @TensorFlow/GenOps/Core.hs@ under the
-- package's autogen modules directory, using the module prefix
-- @TensorFlow.GenOps@. Ops named in 'blackList' are excluded.
generatingOpsWrappers :: UserHooks
generatingOpsWrappers = hooks
    { buildHook = \p l h f -> generateSources l >> buildHook hooks p l h f
    , haddockHook = \p l h f -> generateSources l >> haddockHook hooks p l h f
    , replHook = \p l h f args -> generateSources l
                                      >> replHook hooks p l h f args
    }
  where
    -- Generator settings for a given output directory.
    flagsBuilder dir = OpGenFlags
        { outputFile = dir </> "Core.hs"
        , prefix = "TensorFlow.GenOps"
        , excludeList = intercalate "," blackList
        }
    hooks = simpleUserHooks

    -- Fetch every op registered in the linked TensorFlow library, decode
    -- the serialized op list and write the generated module. Exits the
    -- process with a failure status if decoding fails.
    generateSources :: LocalBuildInfo -> IO ()
    generateSources l = do
        let dir = autogenModulesDir l </> "TensorFlow/GenOps"
        createDirectoryIfMissing True dir
        let flags = flagsBuilder dir
        pb <- getAllOpList
        case decodeMessage pb of
            Left e -> do
                hPutStrLn stderr
                    ("Setup.hs: failed to decode TensorFlow OpList: " ++ e)
                exitFailure
            Right x -> writeUtf8File (outputFile flags)
                                     (prettyLazyText 80 $ docOpList flags x)

    -- GHC reads source files as UTF-8, so the generated module is written as
    -- UTF-8 explicitly instead of using the build machine's locale encoding
    -- (which is what 'Text.writeFile' would use).
    writeUtf8File path contents = withFile path WriteMode $ \h -> do
        hSetEncoding h utf8
        Text.hPutStr h contents
-- | Prepend the package's autogen modules directory to the @hs-source-dirs@
-- of the package's library component, if there is one. Other components
-- (executables, test suites, benchmarks) are returned unchanged.
--
-- Intended to fool @sdist@ by pointing it to the generated source files.
--
-- NOTE(review): nothing in this file calls this function, so it has no
-- effect unless it is wired into a hook (e.g. 'sDistHook') — confirm intent.
fudgePackageDesc
    :: LocalBuildInfo -> PackageDescription -> PackageDescription
fudgePackageDesc lbi p = p { library = fudgeLib <$> library p }
  where
    -- Rewrite only the library's build info.
    fudgeLib lib = lib { libBuildInfo = fudgeBuildInfo (libBuildInfo lib) }
    -- The autogen directory is prepended to the existing source dirs.
    fudgeBuildInfo bi =
        bi { hsSourceDirs = autogenModulesDir lbi : hsSourceDirs bi }
-- | Names of ops for which no wrapper is generated. 'generatingOpsWrappers'
-- joins this list with commas into the 'excludeList' generator flag.
-- Each name appears exactly once.
blackList :: [String]
blackList =
    [ -- Requires the "func" type:
      "BatchFunction"
    , "Case"
    , "ChooseFastestBranchDataset"
    , "ExperimentalGroupByReducerDataset"
    , "ExperimentalGroupByWindowDataset"
    , "ExperimentalMapAndBatchDataset"
    , "ExperimentalMapDataset"
    , "ExperimentalNumaMapAndBatchDataset"
    , "ExperimentalParallelInterleaveDataset"
    , "ExperimentalScanDataset"
    , "ExperimentalTakeWhileDataset"
    , "FilterDataset"
    , "FlatMapDataset"
    , "For"
    , "GeneratorDataset"
    , "GroupByReducerDataset"
    , "GroupByWindowDataset"
    , "If"
    , "InterleaveDataset"
    , "LegacyParallelInterleaveDatasetV2"
    , "LoadDataset"
    , "MapAndBatchDataset"
    , "MapAndBatchDatasetV2"
    , "MapDataset"
    , "MapDefun"
    , "OneShotIterator"
    , "ParallelInterleaveDataset"
    , "ParallelInterleaveDatasetV2"
    , "ParallelInterleaveDatasetV3"
    , "ParallelInterleaveDatasetV4"
    , "ParallelMapDataset"
    , "ParallelMapDatasetV2"
    , "ParseSequenceExample"
    , "ParseSequenceExampleV2"
    , "ParseSingleSequenceExample"
    , "PartitionedCall"
    , "ReduceDataset"
    , "RemoteCall"
    , "SaveDataset"
    , "ScanDataset"
    , "SnapshotDatasetV2"
    , "StatefulPartitionedCall"
    , "StatelessCase"
    , "StatelessIf"
    , "StatelessWhile"
    , "SymbolicGradient"
    , "TakeWhileDataset"
    , "TPUCompile"
    , "TPUPartitionedCall"
    , "TPUReplicate"
    , "While"
    , "XlaIf"
    , "XlaLaunch"
    , "XlaReduce"
    , "XlaReduceWindow"
    , "XlaScatter"
    , "XlaSelectAndScatter"
    , "XlaWhile"
    , "_If"
    , "_TPUReplicate"
    , "_While"
    , "_XlaCompile"
      -- Incorrectly generated:
    , "_FusedBatchNormGradEx"
      -- Could not deduce:
    , "_MklFusedBatchNorm"
    , "_MklFusedBatchNormEx"
    , "_MklFusedBatchNormGrad"
    , "_MklFusedBatchNormGradV2"
    , "_MklFusedBatchNormGradV3"
    , "_MklFusedBatchNormV2"
    , "_MklFusedBatchNormV3"
    ]
-- | Package-wide directory in which Cabal places autogenerated modules.
--
-- The Cabal function that provides it depends on the Cabal version: from
-- Cabal 2.0 on it is 'BuildPaths.autogenPackageModulesDir'; older versions
-- expose 'BuildPaths.autogenModulesDir'.
autogenModulesDir :: LocalBuildInfo -> FilePath
autogenModulesDir = cabalAutogenDir
  where
#if MIN_VERSION_Cabal(2,0,0)
    cabalAutogenDir = BuildPaths.autogenPackageModulesDir
#else
    cabalAutogenDir = BuildPaths.autogenModulesDir
#endif