mirror of
https://github.com/tensorflow/haskell.git
synced 2024-06-02 19:13:34 +02:00
Fixes.
This commit is contained in:
parent
14a39f3f49
commit
63e0fae505
|
@ -211,11 +211,14 @@ whereClause :: [Attr (NonEmpty Name)] -> [Doc]
|
|||
whereClause [] = []
|
||||
whereClause as = [indent 2 $ "where" </> indent 2 (stack $ map defineLengthAttr as)]
|
||||
where
|
||||
defineLengthAttr a = renderHaskellName (attrName a) <+> "="
|
||||
defineLengthAttr a = renderHaskellAttrName a <+> "="
|
||||
<+> "fromIntegral (length"
|
||||
<+> renderHaskellName (NE.head $ attrInfo a)
|
||||
<> ") :: Int64"
|
||||
|
||||
renderHaskellAttrName :: Attr a -> Doc
|
||||
renderHaskellAttrName = renderHaskellName . attrName
|
||||
|
||||
functionBody :: ParsedOp -> Doc
|
||||
functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
||||
</> indent indentation (sep tensorArgs)
|
||||
|
@ -246,9 +249,9 @@ functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
|
|||
tensorArgs = renderHaskellName . parsedArgName <$> parsedInputs pOp
|
||||
inferredTypeExpr a
|
||||
| typeParamIsList $ attrInfo a
|
||||
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellName (attrName a)
|
||||
= "fromTensorTypes (Proxy :: Proxy" <+> renderHaskellAttrName a
|
||||
<> ")"
|
||||
| otherwise = "tensorType (undefined ::" <+> renderHaskellName (attrName a)
|
||||
| otherwise = "tensorType (undefined ::" <+> renderHaskellAttrName a
|
||||
<> ")"
|
||||
|
||||
-- | Write a comment with the inputs/outputs/attributes in proto format, for
|
||||
|
@ -278,7 +281,7 @@ typeSig pOp = constraints
|
|||
| otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
|
||||
typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
|
||||
Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
|
||||
++ [renderHaskellName $ attrName n | n <- inferredTypeAttrs pOp]
|
||||
++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
|
||||
classConstraints = tuple $ map tensorArgConstraint
|
||||
$ inferredTypeAttrs pOp
|
||||
signatureFold = folddoc (\x y -> x </> "->" <+> y)
|
||||
|
@ -362,11 +365,11 @@ tensorArgConstraint a = case attrInfo a of
|
|||
TypeParam True Nothing -> "TensorTypes" <+> n
|
||||
TypeParam True (Just as) -> "OneOfs" <+> typeList as <+> n
|
||||
where
|
||||
n = renderHaskellName $ attrName a
|
||||
n = renderHaskellAttrName a
|
||||
-- Produces a type-level list, e.g.: '[Int32,Int64,Float]
|
||||
typeList = ("'" <>) . brackets . commasep . map strictText .
|
||||
Set.toList . Set.fromList .
|
||||
map dtTypeToHaskell
|
||||
map dtTypeToHaskell . toList
|
||||
|
||||
-- NOTE: The cases of this function should be kept in sync with
|
||||
-- TensorFlow.Types.AllTensorTypes.
|
||||
|
|
|
@ -105,7 +105,7 @@ data AttrBaseType = AttrBytes | AttrInt64 | AttrFloat | AttrBool
|
|||
|
||||
data TypeParam = TypeParam
|
||||
{ typeParamIsList :: Bool
|
||||
, typeParamRestrictions :: Maybe [DataType]
|
||||
, typeParamRestrictions :: Maybe (NonEmpty DataType)
|
||||
-- ^ The list of allowed types (see: TensorFlow.Types.OneOf).
|
||||
-- If 'Nothing', then any type is acceptable.
|
||||
}
|
||||
|
@ -273,9 +273,7 @@ getInferredTypeAttr argTypeParams a
|
|||
| a ^. type' == "list(type)" = Just $ TypeParam True allowed
|
||||
| otherwise = Nothing
|
||||
where
|
||||
allowed = case a ^. allowedValues . list . type' of
|
||||
[] -> Nothing
|
||||
as -> Just as
|
||||
allowed = nonEmpty (a ^. allowedValues . list . type')
|
||||
|
||||
getArgTypeParam :: ParsedArgCase -> Maybe Name
|
||||
getArgTypeParam SimpleArg { argType = ArgTypeAttr n} = Just n
|
||||
|
|
|
@ -49,7 +49,7 @@ enqueue q =
|
|||
dequeue :: forall as . TensorTypes as
|
||||
=> Queue as
|
||||
-> Build (TensorList Ref as)
|
||||
-- ^ Dequeued tensors. They are paired in a sense
|
||||
-- ^ Dequeued tensors. They are coupled in a sense
|
||||
-- that values appear together, even if they are
|
||||
-- not consumed together.
|
||||
dequeue q =
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
-- limitations under the License.
|
||||
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
|
@ -406,6 +407,11 @@ type family Map f as where
|
|||
instance All Eq (Map f as) => Eq (ListOf f as) where
|
||||
Nil == Nil = True
|
||||
(x :| xs) == (y :| ys) = x == y && xs == ys
|
||||
-- Newer versions of GHC use the GADT to tell that the previous cases are
|
||||
-- exhaustive.
|
||||
#if _GLASGOW_HASKELL__ < 800
|
||||
_ == _ = False
|
||||
#endif
|
||||
|
||||
instance All Show (Map f as) => Show (ListOf f as) where
|
||||
showsPrec _ Nil = showString "Nil"
|
||||
|
|
Loading…
Reference in New Issue
Block a user