-- Copyright 2016 TensorFlow authors.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
-- | Rendering of TensorFlow operations as Haskell functions.

module TensorFlow.OpGen
  ( OpGenFlags(..)
  , docOpList
  , flagParser)
  where

import Prelude hiding (head, tail)

import Control.Monad (guard)
import Data.Char (toLower, toUpper)
import Data.Foldable (toList)
import Data.Maybe (fromMaybe, maybeToList)
import Data.ProtoLens (def, showMessage)
import Data.List.NonEmpty (NonEmpty((:|)), head)
import qualified Data.List.NonEmpty as NE
import Lens.Family2 ((^.), (.~), (&), view)
import Options.Applicative (Parser, help, long, strOption, value)
import Proto.Tensorflow.Core.Framework.OpDef
  ( OpList
  , OpDef
  , OpDef'ArgDef
  , attr
  , description
  , inputArg
  , name
  , numberAttr
  , op
  , outputArg
  , summary
  , type'
  , typeAttr
  )
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
import System.FilePath (takeBaseName)
import TensorFlow.OpGen.AttrVal
  (AttrDef
  , AttrCase(..)
  , AttrTemplate(..)
  , Template
  , attrDef
  , attrOriginal
  , attrTemplate
  , templateDefault
  , templateRestrictions
  )
import Text.PrettyPrint.Mainland
  ( Doc
  , (<>)
  , (<+>)
  , (</>)
  , (<+/>)
  , brackets
  , comma
  , commasep
  , dquotes
  , empty
  , enclose
  , flatten
  , folddoc
  , hang
  , indent
  , int
  , parens
  , sep
  , stack
  , strictText
  , tuple
  )
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as Text
import qualified Data.Semigroup as Semigroup
import Data.Text (Text)

-- | Command-line options of the op generator, as parsed by 'flagParser'.
data OpGenFlags = OpGenFlags
     { outputFile :: String
       -- ^ Value of @--output@ ("File to write."). 'docOpList' also derives
       -- the generated module's name from the base name of this path.
     , prefix :: String
       -- ^ Value of @--prefix@: the Haskell package prefix, prepended
       -- (followed by a dot) to the generated module's name.
     , excludeList :: String
       -- ^ Value of @--exclude_list@: comma-separated names of ops that
       -- 'docOpList' leaves out of the generated module.
     }

-- | Command-line parser for 'OpGenFlags'. The @--output@ and @--prefix@
-- options carry no default value here; @--exclude_list@ defaults to the
-- empty string.
flagParser :: Parser OpGenFlags
flagParser = OpGenFlags <$> outputOpt <*> prefixOpt <*> excludeOpt
  where
    outputOpt = strOption $ mconcat
        [ long "output"
        , help "File to write."
        ]
    prefixOpt = strOption $ mconcat
        [ long "prefix"
        , help "Haskell package prefix to use"
        ]
    excludeOpt = strOption $ mconcat
        [ long "exclude_list"
        , value ""
        , help "Comma separated Ops names to ignore"
        ]


-- | Renders a complete Haskell module for the given op list: the LANGUAGE
-- pragmas, the module header, the fixed import block and one definition
-- (see 'renderDef') per op, separated by blank lines.
--
-- The module name is the 'prefix' flag, a dot, and the camel-cased base
-- name of 'outputFile' (with an optional trailing @_op_lib@ removed).
-- Ops whose name appears in the comma-separated 'excludeList' flag are
-- skipped; whitespace around each listed name is ignored, so
-- @"Foo, Bar"@ excludes both @Foo@ and @Bar@.
docOpList :: OpGenFlags -> OpList -> Doc
docOpList flags opList =
  stack [ "{-# LANGUAGE ConstraintKinds #-}"
        , "{-# LANGUAGE DataKinds #-}"
        , "{-# LANGUAGE FlexibleInstances #-}"
        , "{-# LANGUAGE OverloadedStrings #-}"
        , "{-# LANGUAGE RankNTypes #-}"
        , "{-# LANGUAGE ScopedTypeVariables #-}"
        , "module" <+> strictText moduleName <+> "where"
        , empty
        , imports
        , empty
        , folddoc (\x y -> x </> empty </> y)
                  (map renderDef $
                   filter ((`Set.notMember` exclusions) . view name) $
                   toList $ opList ^. op)
        ]
  where moduleName =
            Text.pack (prefix flags) <> "." <> camelCase
             -- Discards the optional trailing _op_lib
            (fromMaybe shortName (Text.stripSuffix "_op_lib" shortName))
        shortName = Text.pack (takeBaseName $ outputFile flags)
        -- Names are stripped so that spaces after the commas do not
        -- silently prevent an op from being excluded.
        exclusions = Set.fromList $ map Text.strip $
                     Text.splitOn "," $ Text.pack $ excludeList flags

-- | Converts an underscore-separated name to camel case: the text is split
-- on @_@, segments equal to @ops@ are dropped, and the remaining segments
-- are concatenated after upper-casing the first character of each.
camelCase :: Text -> Text
camelCase = Text.concat
          . map upCase
          . filter (/= "ops")
          . Text.splitOn "_"

-- | Upper-cases the first character of the given text; the remaining
-- characters are left untouched.
upCase :: Text -> Text
upCase s = forceCase toUpper s

-- | Lower-cases the first character of the given name and, if the result is
-- a reserved Haskell name, renames it via 'replaceReservedName'.
lowCase :: Text -> Text
lowCase s = replaceReservedName (forceCase toLower s)

-- | Applies the given character conversion to the first character of the
-- text only. Empty input yields empty output.
forceCase :: (Char -> Char) -> Text -> Text
forceCase convert s =
    case Text.uncons s of
        Nothing -> Text.empty
        Just (firstChar, rest) -> Text.cons (convert firstChar) rest

-- | The fixed import block emitted at the top of every generated module,
-- one import per line.
imports :: Doc
imports = stack
    [ "import Data.ByteString (ByteString)"
    , "import Data.Complex (Complex)"
    , "import Data.Int (Int8, Int16, Int32, Int64)"
    , "import Data.Word (Word8, Word16)"
    , "import Lens.Family2 ((.~), (&))"
    , "import TensorFlow.Build"
    , "import TensorFlow.BuildOp"
    , "import TensorFlow.Tensor"
    , "import TensorFlow.Types"
    ]

-- | Renders a single op as Haskell source: a haddock comment built from the
-- op's summary and description, the type signature produced by 'typeSig',
-- a guarded function equation that builds the op, and a block comment
-- holding a text dump of the op's input args, output args and attrs.
--
-- The generated function takes the mandatory attributes first (see
-- 'mandatoryAttrMap') followed by one parameter per input tensor.
renderDef :: OpDef -> Doc
renderDef d =
  stack [
      haddocks
    , n <+> "::" <+> hang 0 (typeSig d)
    , n <+> hang 0 args <+> "|" <+> funcGuard <+> "=" </>  -- args are indented
            -- the body needs to be indented wrt the name
            indent indentation functionBody
    , extras  -- just for debug
    ]
  where
    n = strictText $ fixOpName (d ^. name)
    fixOpName = lowCase
    -- Both maps are consulted several times below; compute each once.
    mandatoryMap = mandatoryAttrMap d
    numberAttrs = Map.toList $ numberAttrMap d
    args = sep $ [hsName | (_, hsName) <- mandatoryAttrs] ++ tensorArgs
    tensorArgs = [strictText $ lowCase (a ^. name) | a <- d ^. inputArg]
    -- Guard of the generated equation: for every number attribute it hands
    -- @eqLengthGuard@ the attribute name together with a
    -- (tensor name, length) pair for each tensor list sized by it.
    funcGuard = "eqLengthGuard" <+> brackets (commasep entries)
      where
        entries =
            [ parens $ quotedText nAttr <> comma <+>
              brackets (commasep $ toList $
              NE.map renderTensorName tensorNames)
            | (nAttr, tensorNames) <- numberAttrs
            ]
        renderTensorName x = parens $ quotedText x <> comma <+>
                             "length" <+> strictText x
    -- Uses hang 0 to align the argument vertically on multiple lines.
    functionBody = buildFunction <+> parens (hang 0 (stack buildOpParts))
                                 </> indent indentation (sep tensorArgs)
    buildFunction
        | null outputListsSizes = "buildOp"
        | otherwise = "buildListOp" <+> brackets (commasep outputListsSizes)
    -- Sizes of list-valued outputs. Each size is a mandatory attribute and
    -- therefore a parameter of the generated function, so it must be
    -- referred to by its Haskell parameter name rather than by the raw
    -- TensorFlow attribute name (the two differ when 'lowCase' rewrites it).
    outputListsSizes = [ strictText hsName
                       | o <- d ^. outputArg
                       , let numberAttrName = o ^. numberAttr
                       , not (Text.null numberAttrName)
                       , Just (hsName, _, _) <-
                             [Map.lookup numberAttrName mandatoryMap]
                       ]
    buildOpParts =
        "opDef" <+> quotedText (d ^. name) :
        -- Renders the type attributes, one per entry of the type map.
        [ "& opAttr" <+> quotedText tfName <+>
          ".~ tensorType (undefined ::" <+> strictText hsName <> ")"
        | (tfName, (hsName, _)) <- Map.toList typeMap
        ] ++
        -- Renders mandatory attributes as function parameters.
        [ "& opAttr" <+> dquotes tfName <+> ".~" <+> hsName
        | (tfName, hsName) <- mandatoryAttrs
        ] ++
        -- Renders sizes of tensor list types having number_attr.
        [ "& opAttr" <+> quotedText nAttr <+> ".~" <+>
          "(fromIntegral (length" <+> strictText (head tensorNames) <> ") :: Int64)"
        | (nAttr, tensorNames) <- numberAttrs
        ]
    -- (TensorFlow attribute name, Haskell parameter name), both as Docs.
    mandatoryAttrs = [(strictText tf, strictText hs)
                     | (tf, (hs, _, _)) <- Map.toList mandatoryMap
                     ]
    haddocks = "-- |" <+> multilineComment (d ^. summary) (d ^. description)
    extras = enclose "{-\n" "\n-}" $
             strictText $ Text.pack $
             showMessage ((def :: OpDef)
                          & inputArg .~ (d ^. inputArg)
                          & outputArg .~ (d ^. outputArg)
                          & attr .~ (d ^. attr))
    typeMap = opDefTypeMap d

-- | Renders the given text wrapped in double quotes. The text is emitted
-- as is; no escaping is applied.
quotedText :: Text.Text -> Doc
quotedText t = dquotes (strictText t)

-- | Renders the type signature of the given 'OpDef': an optional @forall@
-- and constraint context (both present only when the op has type
-- attributes), then one parameter per mandatory attribute, one per input
-- tensor, and finally the result type, all joined with @->@.
typeSig :: OpDef -> Doc
typeSig d =
    quantifier <+> context <+/>
    arrowChain (attrParams ++ inputParams ++ [resultType])
  where
    typeMap = opDefTypeMap d
    typeVars = Map.elems typeMap
    inputs = d ^. inputArg

    -- One type variable (v1, v2, ...) per input tensor; it becomes the
    -- first argument of that input's @Tensor@ type.
    refTypes = ["v" <> int i | i <- [1 .. length inputs]]

    quantifier
        | Map.null typeMap = empty
        | otherwise =
            "forall"
            <+> sep (refTypes ++ map (strictText . fst) typeVars)
            <+> "."

    context
        | Map.null typeMap = empty
        | otherwise = tuple (concatMap varConstraints typeVars) <+> "=>"
    -- A TensorType constraint per type variable, plus a OneOf restriction
    -- when the attribute limits the allowed types.
    varConstraints (t, aDef) =
        "TensorType" <+> strictText t
        : maybeToList (oneOfRestrictions aDef t)

    attrParams =
        [ dtTypeToDoc dtType <+>
              hang 0 ("-- ^" <+> argComment' tfName descr)
        | (tfName, (_, dtType, descr)) <- Map.toList (mandatoryAttrMap d)
        ]

    inputParams = zipWith inputParam refTypes inputs
    inputParam refType arg =
        wrapArg refType arg <+> hang 0 ("-- ^" <+> argComment arg)

    -- An argument with number_attr set is a list of tensors; any other
    -- argument is a single tensor.
    wrapArg refType arg
        | Text.null (arg ^. numberAttr) = single
        | otherwise = brackets single
      where single = tensorOf refType arg

    -- The element type is the type variable of the argument's type
    -- attribute when there is one in the type map, else its fixed type.
    tensorOf refType arg = "Tensor" <+> refType <+> elemType
      where
        elemType = case Map.lookup (arg ^. typeAttr) typeMap of
            Just (tyVar, _) -> strictText tyVar
            Nothing -> dtTypeToDoc (arg ^. type')

    resultType = case d ^. outputArg of
        []  -> "ControlNode"
        [o] -> valueOutput o <+> "-- ^" <+> argComment o
        os  -> tupleResult os
    valueOutput = wrapArg "Value"

    -- A tuple result is rendered with a summary comment naming every
    -- element, followed by a separate comment entry for each element.
    tupleResult os =
        stack (tuple (map valueOutput os)
               : flatten summaryLine
               : map detailLines os)
      where
        summaryLine = "-- ^" <+> tuple [bold (o ^. name) | o <- os]
        detailLines o = stack [ "--"
                              , "-- *" <+> argComment o
                              ]

    arrowChain = folddoc (\lhs rhs -> lhs </> "->" <+> rhs)

-- | Renders a @OneOf@ constraint for the named type variable. Yields
-- 'Nothing' when the attribute is not a single type attribute or when its
-- set of allowed types is empty (i.e. unrestricted). The listed Haskell
-- types are de-duplicated and sorted.
oneOfRestrictions :: AttrDef -> Text -> Maybe Doc
oneOfRestrictions aDef tName =
    onAttrType (view templateRestrictions) aDef >>= render
  where
    render allowed =
        guard (not (null allowed)) >>
        Just ("OneOf" <+> "'" <> brackets (haskellTypes allowed)
                      <+> strictText tName)
    haskellTypes =
        commasep . map strictText . Set.toList . Set.fromList
                 . map dtTypeToHaskell

-- | Identifies the attributes used as tensor cardinalities. In such
-- cases a list of tensors is supplied as an input_arg. The number of
-- such inputs is communicated as a separate opAttr.
--
-- The result key is the TensorFlow attribute name and the value holds the
-- names of the input tensors which have number_attr set to that key.
-- The tensor names are passed through 'lowCase', i.e. exactly the way
-- 'renderDef' names the corresponding function parameters, so generated
-- code that refers to them always matches a bound parameter.
--
-- NOTE(review): 'Map.fromListWith' combines as @new <> old@, so for a key
-- shared by several inputs the names appear in reverse input order.
numberAttrMap :: OpDef -> Map.Map Text.Text (NonEmpty Text.Text)
numberAttrMap d = Map.fromListWith (Semigroup.<>) [
    (nAttr, lowCase (inp ^. name) :| [])
    | inp <- d ^. inputArg
    , nAttr <- [inp ^. numberAttr]
    , not (Text.null nAttr)
    ]

-- | Renders the comment for an op argument from its name and description.
argComment :: OpDef'ArgDef -> Doc
argComment arg = argComment' (view name arg) (view description arg)

-- | Renders the name in bold followed by the description. The first
-- description line is introduced by a colon; an empty description adds
-- nothing after the name.
argComment' :: Text.Text -> Text.Text -> Doc
argComment' argName argDesc = bold argName <> details
  where
    details = splitMultilineText afterColon argDesc
    afterColon firstLine = ":" <+> firstLine

-- | Wraps the given text in double underscores (haddock bold markup).
bold :: Text.Text -> Doc
bold n = strictText (Text.concat ["__", n, "__"])

-- | Collects the op's type attributes (those accepted by 'isType'). The
-- key is the TensorFlow attribute name; the value pairs the Haskell type
-- variable name ('lowCase' of the key) with the attribute definition.
opDefTypeMap :: OpDef -> Map.Map Text.Text (Text.Text, AttrDef)
opDefTypeMap d = Map.fromList (map entry typeAttrs)
  where
    typeAttrs = filter (isType . snd) (attrList d)
    entry (tfName, aDef) = (tfName, (lowCase tfName, aDef))

-- | Lists the op's attributes in declaration order, each paired with its
-- TensorFlow name and converted with 'attrDef'.
attrList :: OpDef -> [(Text.Text, AttrDef)]
attrList d = map named (d ^. attr)
  where
    named a = (a ^. name, attrDef a)

-- | True exactly when the attribute is a single (non-list) type attribute.
isType :: AttrDef -> Bool
isType aDef =
    case aDef ^. attrTemplate of
        AttrSingle (AttrType _) -> True
        _ -> False

-- | Applies the given function to the template of a single (non-list) type
-- attribute; yields 'Nothing' for every other kind of attribute.
-- Is this a Prism?
onAttrType :: (Template DataType -> a) -> AttrDef -> Maybe a
onAttrType f aDef
    | AttrSingle (AttrType t) <- aDef ^. attrTemplate = Just (f t)
    | otherwise = Nothing

-- | Collects the attributes selected by 'isMandatoryAttr', leaving out
-- those that are rendered as list lengths (the keys of 'numberAttrMap').
-- The key is the TF name of the attribute; the value tuple is
-- (haskell name, TF type, attribute comment).
mandatoryAttrMap :: OpDef -> Map.Map Text.Text (Text.Text, DataType, Text.Text)
mandatoryAttrMap d = Map.fromList
    [ (tfName, (lowCase tfName, dtType, aDef ^. attrOriginal . description))
    | (tfName, aDef) <- attrList d
    , tfName `Map.notMember` lengthAttrs
    , dtType <- maybeToList (isMandatoryAttr aDef)
    ]
  where
    -- Attributes derived from the length of a tensor list argument.
    lengthAttrs = numberAttrMap d

-- | Returns the TF type of the attribute when it is a single bool, int64
-- or float attribute without a default value; returns 'Nothing' for all
-- other attributes, including these three kinds when a default exists.
isMandatoryAttr :: AttrDef -> Maybe DataType
isMandatoryAttr aDef =
    case aDef ^. attrTemplate of
        AttrSingle (AttrBool t)  -> DT_BOOL `unlessDefaulted` t
        AttrSingle (AttrInt64 t) -> DT_INT64 `unlessDefaulted` t
        AttrSingle (AttrFloat t) -> DT_FLOAT `unlessDefaulted` t
        _ -> Nothing
  where
    -- Keeps the type only if the template carries no default.
    unlessDefaulted typ t =
        case t ^. templateDefault of
            Nothing -> Just typ
            Just _ -> Nothing

-- | Renders the Haskell type name that 'dtTypeToHaskell' gives for the
-- TensorFlow type.
dtTypeToDoc :: DataType -> Doc
dtTypeToDoc dt = strictText (dtTypeToHaskell dt)

-- | Maps a TensorFlow data type to the text of the Haskell type used for
-- it in generated code. A type without a mapping yields a diagnostic
-- message as the text instead of failing.
--
-- NOTE: The cases of this function should be kept in sync with
-- TensorFlow.Types.AllTensorTypes.
dtTypeToHaskell :: DataType -> Text.Text
dtTypeToHaskell dt = case dt of
    DT_BOOL       -> "Bool"
    DT_BFLOAT16   -> "Data.Word.Word16"
    DT_COMPLEX128 -> "(Data.Complex.Complex Double)"
    DT_COMPLEX64  -> "(Data.Complex.Complex Float)"
    DT_DOUBLE     -> "Double"
    DT_FLOAT      -> "Float"
    DT_INT16      -> "Data.Int.Int16"
    DT_INT32      -> "Data.Int.Int32"
    DT_INT64      -> "Data.Int.Int64"
    DT_INT8       -> "Data.Int.Int8"
    DT_QINT32     -> "Data.Int.Int32"  -- TODO(gnezdo): make unique
    DT_QINT8      -> "Data.Word.Word8"  -- TODO(gnezdo): make unique
    DT_QINT16     -> "Data.Int.Int16"  -- TODO(gnezdo): make unique
    DT_QUINT16    -> "Data.Word.Word16"  -- TODO(gnezdo): make unique
    DT_QUINT8     -> "Data.Word.Word8"  -- TODO(gnezdo): make unique
    DT_STRING     -> "Data.ByteString.ByteString"
    DT_UINT16     -> "Data.Word.Word16"
    DT_HALF       -> "Data.Word.Word16"  -- TODO(gnezdo): make unique
    DT_UINT8      -> "Data.Word.Word8"
    _ -> Text.pack $ "Unsupported type in dtTypeToHaskell: " ++ show dt

-- | Turns a TensorFlow doc string into haddock text. Currently the text is
-- passed through unchanged.
-- TODO(gnezdo): deal with the markup.
haddockComment :: Text.Text -> Doc
haddockComment docString = strictText docString

-- | Renders the summary followed, on the next line, by the detail text.
-- The first detail line is preceded by an empty comment line so that it
-- starts a new paragraph; the remaining lines are laid out by
-- 'splitMultilineText'.
multilineComment :: Text.Text -> Text.Text -> Doc
multilineComment summary' detail =
    haddockComment summary' </> detailDoc
  where
    detailDoc = splitMultilineText newParagraph detail
    newParagraph firstLine = "--" </> "--" <+> firstLine

-- | Lays out a multi-line detail string as haddock comment lines, one
-- source line per output line. The given lead is applied to the first
-- line; every following line is prefixed with @--@. Empty detail yields
-- an empty document.
splitMultilineText :: (Doc -> Doc) -> Text.Text -> Doc
splitMultilineText lead detail
    | (firstLine : rest) <- Text.lines detail =
        stack (lead (haddockComment firstLine) : map continuation rest)
    | otherwise = empty
  where
    continuation l = "--" <+> haddockComment l

-- | Appends a prime (@'@) to the name when it is one of
-- 'reservedKeywords'; any other name is returned unchanged.
replaceReservedName :: Text -> Text
replaceReservedName n =
    if Set.member n reservedKeywords
        then Text.snoc n '\''
        else n

-- | Number of columns by which 'renderDef' indents the body of a generated
-- function and its trailing tensor arguments.
indentation :: Int
indentation = 4

-- | Names that must not be used as identifiers in generated code.
reservedKeywords :: Set.Set Text
reservedKeywords =
    Set.union (Set.fromList haskell2010Keywords)
              (Set.fromList extensionKeywords)
  where
    -- Haskell2010 keywords:
    -- https://www.haskell.org/onlinereport/haskell2010/haskellch2.html#x7-180002.4
    -- Keywords that are allowed to be variable names are left out,
    -- in particular: "as", "forall", and "hiding".
    haskell2010Keywords =
        [ "case", "class", "data", "default", "deriving", "do", "else"
        , "foreign", "if", "import", "in", "infix", "infixl", "infixr"
        , "instance", "let", "module", "newtype", "of", "then", "type"
        , "where"
        ]
    -- Keywords introduced by nonstandard extensions.
    extensionKeywords =
        [ "mdo"   -- RecursiveDo
        , "rec"   -- Arrows, RecursiveDo
        , "proc"  -- Arrows
        ]