1
0
mirror of https://github.com/tensorflow/haskell.git synced 2024-06-02 19:13:34 +02:00
This commit is contained in:
Judah Jacobson 2017-02-27 09:47:33 -08:00
parent 14a39f3f49
commit 63e0fae505
4 changed files with 18 additions and 11 deletions

View File

@ -211,11 +211,14 @@ whereClause :: [Attr (NonEmpty Name)] -> [Doc]
whereClause [] = []
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
where
defineLengthAttr a = renderHaskellName (attrName a) <+> "="
defineLengthAttr a = renderHaskellAttrName a <+> "="
<+> "fromIntegral (length"
<+> renderHaskellName (NE.head $ attrInfo a)
<> ") :: Int64"
renderHaskellAttrName :: Attr a -> Doc
renderHaskellAttrName = renderHaskellName . attrName
functionBody :: ParsedOp -> Doc
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
</> indent indentation (sep tensorArgs)
@ -246,9 +249,9 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
inferredTypeExpr a
| typeParamIsList $ attrInfo a
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellName (attrName a)
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
<> ")"
| otherwise = "tensorType (undefined ::" <+> renderHaskellName (attrName a)
| otherwise = "tensorType (undefined ::" <+> renderHaskellAttrName a
<> ")"
-- | Write a comment with the inputs/outputs/attributes in proto format, for
@ -278,7 +281,7 @@ typeSig pOp = constraints
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
classConstraints = tuple $ map tensorArgConstraint
$ inferredTypeAttrs pOp
signatureFold = folddoc (\x y -> x </> "->" <+> y)
@ -362,11 +365,11 @@ tensorArgConstraint a = case attrInfo a of
TypeParam True Nothing -> "TensorTypes" <+> n
TypeParam True (Just as) -> "OneOfs" <+> typeList as <+> n
where
n = renderHaskellName $ attrName a
n = renderHaskellAttrName a
-- Produces a type-level list, e.g.: '[Int32,Int64,Float]
typeList = ("'" <>) . brackets . commasep . map strictText .
Set.toList . Set.fromList .
map dtTypeToHaskell
map dtTypeToHaskell . toList
-- NOTE: The cases of this function should be kept in sync with
-- TensorFlow.Types.AllTensorTypes.

View File

@ -105,7 +105,7 @@ data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
data TypeParam = TypeParam
{ typeParamIsList :: Bool
, typeParamRestrictions :: Maybe [DataType]
, typeParamRestrictions :: Maybe (NonEmpty DataType)
-- ^ The list of allowed types (see: TensorFlow.Types.OneOf).
-- If 'Nothing', then any type is acceptable.
}
@ -273,9 +273,7 @@ getInferredTypeAttr argTypeParams a
| a ^. type' == "list(type)" = Just $ TypeParam True allowed
| otherwise = Nothing
where
allowed = case a ^. allowedValues . list . type' of
[] -> Nothing
as -> Just as
allowed = nonEmpty (a ^. allowedValues . list . type')
getArgTypeParam :: ParsedArgCase -> Maybe Name
getArgTypeParam SimpleArg { argType = ArgTypeAttr n} = Just n

View File

@ -49,7 +49,7 @@ enqueue q =
dequeue :: forall as . TensorTypes as
=> Queue as
-> Build (TensorList Ref as)
-- ^ Dequeued tensors. They are paired in a sense
-- ^ Dequeued tensors. They are coupled in a sense
-- that values appear together, even if they are
-- not consumed together.
dequeue q =

View File

@ -13,6 +13,7 @@
-- limitations under the License.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
@ -406,6 +407,11 @@ type family Map f as where
instance All Eq (Map f as) => Eq (ListOf f as) where
Nil == Nil = True
(x :| xs) == (y :| ys) = x == y && xs == ys
-- Newer versions of GHC use the GADT to tell that the previous cases are
-- exhaustive.
#if _GLASGOW_HASKELL__ < 800
_ == _ = False
#endif
instance All Show (Map f as) => Show (ListOf f as) where
showsPrec _ Nil = showString "Nil"