Skip to content
3 changes: 2 additions & 1 deletion Network/Socket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ module Network.Socket

-- * Socket options
, SocketOption(SockOpt
,UnsupportedSocketOption
,Debug,ReuseAddr,SoDomain,Type,SoProtocol,SoError,DontRoute
,Broadcast,SendBuffer,RecvBuffer,KeepAlive,OOBInline
,TimeToLive,MaxSegment,NoDelay,Cork,Linger,ReusePort
Expand Down Expand Up @@ -244,7 +245,7 @@ module Network.Socket
,CmsgIdIPv6TClass
,CmsgIdIPv4PktInfo
,CmsgIdIPv6PktInfo
,CmsgIdFd)
,UnsupportedCmsgId)
-- ** APIs for control message
, lookupCmsg
, filterCmsg
Expand Down
19 changes: 19 additions & 0 deletions Network/Socket/Flag.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import qualified Data.Semigroup as Sem

import Network.Socket.Imports

{-
import Network.Socket.ReadShow

import qualified Text.Read as P
-}

-- | Message flags. To combine flags, use '(<>)'.
newtype MsgFlag = MsgFlag { fromMsgFlag :: CInt }
deriving (Show, Eq, Ord, Num, Bits)
Expand Down Expand Up @@ -78,3 +84,16 @@ pattern MSG_WAITALL = MsgFlag (#const MSG_WAITALL)
#else
pattern MSG_WAITALL = MsgFlag 0
#endif

{-
msgFlagPairs :: [Pair MsgFlag String]
msgFlagPairs =
[ (MSG_OOB, "MSG_OOB")
, (MSG_DONTROUTE, "MSG_DONTROUTE")
, (MSG_PEEK, "MSG_PEEK")
, (MSG_EOR, "MSG_EOR")
, (MSG_TRUNC, "MSG_TRUNC")
, (MSG_CTRUNC, "MSG_CTRUNC")
, (MSG_WAITALL, "MSG_WAITALL")
]
-}
76 changes: 71 additions & 5 deletions Network/Socket/Options.hsc
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}

#include "HsNet.h"
##include "HsNetDef.h"

module Network.Socket.Options (
SocketOption(SockOpt
,UnsupportedSocketOption
,Debug,ReuseAddr,SoDomain,Type,SoProtocol,SoError,DontRoute
,Broadcast,SendBuffer,RecvBuffer,KeepAlive,OOBInline,TimeToLive
,MaxSegment,NoDelay,Cork,Linger,ReusePort
Expand All @@ -25,12 +27,15 @@ module Network.Socket.Options (
, setSockOpt
) where

import qualified Text.Read as P

import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Utils (with)

import Network.Socket.Imports
import Network.Socket.Internal
import Network.Socket.Types
import Network.Socket.ReadShow

-----------------------------------------------------------------------------
-- Socket Properties
Expand All @@ -39,10 +44,15 @@ import Network.Socket.Types
--
-- The existence of a constructor does not imply that the relevant option
-- is supported on your system: see 'isSupportedSocketOption'
data SocketOption = SockOpt {
sockOptLevel :: !CInt
, sockOptName :: !CInt
} deriving (Eq, Show)
data SocketOption = SockOpt
#if __GLASGOW_HASKELL__ >= 806
!CInt -- ^ Option Level
Comment thread
archaephyrryx marked this conversation as resolved.
!CInt -- ^ Option Name
Comment thread
archaephyrryx marked this conversation as resolved.
#else
!CInt -- Option Level
!CInt -- Option Name
#endif
deriving (Eq)
Comment thread
archaephyrryx marked this conversation as resolved.

-- | Does the 'SocketOption' exist on this system?
isSupportedSocketOption :: SocketOption -> Bool
Expand All @@ -54,6 +64,9 @@ isSupportedSocketOption opt = opt /= SockOpt (-1) (-1)
getSocketType :: Socket -> IO SocketType
getSocketType s = unpackSocketType <$> getSockOpt s Type

pattern UnsupportedSocketOption :: SocketOption
pattern UnsupportedSocketOption = SockOpt (-1) (-1)

#ifdef SOL_SOCKET
-- | SO_DEBUG
pattern Debug :: SocketOption
Expand Down Expand Up @@ -141,7 +154,7 @@ pattern OOBInline :: SocketOption
#ifdef SO_OOBINLINE
pattern OOBInline = SockOpt (#const SOL_SOCKET) (#const SO_OOBINLINE)
#else
pattern OOBINLINE = SockOpt (-1) (-1)
pattern OOBInline = SockOpt (-1) (-1)
#endif
-- | SO_LINGER: timeout in seconds, 0 means disabling/disabled.
pattern Linger :: SocketOption
Expand Down Expand Up @@ -376,6 +389,59 @@ getSockOpt s (SockOpt level opt) = do
c_getsockopt fd level opt ptr ptr_sz
peek ptr


socketOptionPairs :: [Pair SocketOption String]
socketOptionPairs =
[ (UnsupportedSocketOption, "UnsupportedSocketOption")
, (Debug, "Debug")
, (ReuseAddr, "ReuseAddr")
, (SoDomain, "SoDomain")
, (Type, "Type")
, (SoProtocol, "SoProtocol")
, (SoError, "SoError")
, (DontRoute, "DontRoute")
, (Broadcast, "Broadcast")
, (SendBuffer, "SendBuffer")
, (RecvBuffer, "RecvBuffer")
, (KeepAlive, "KeepAlive")
, (OOBInline, "OOBInline")
, (Linger, "Linger")
, (ReusePort, "ReusePort")
, (RecvLowWater, "RecvLowWater")
, (SendLowWater, "SendLowWater")
, (RecvTimeOut, "RecvTimeOut")
, (SendTimeOut, "SendTimeOut")
, (UseLoopBack, "UseLoopBack")
, (MaxSegment, "MaxSegment")
, (NoDelay, "NoDelay")
, (UserTimeout, "UserTimeout")
, (Cork, "Cork")
, (TimeToLive, "TimeToLive")
, (RecvIPv4TTL, "RecvIPv4TTL")
, (RecvIPv4TOS, "RecvIPv4TOS")
, (RecvIPv4PktInfo, "RecvIPv4PktInfo")
, (IPv6Only, "IPv6Only")
, (RecvIPv6HopLimit, "RecvIPv6HopLimit")
, (RecvIPv6TClass, "RecvIPv6TClass")
, (RecvIPv6PktInfo, "RecvIPv6PktInfo")
]

socketOptionBijection :: Bijection SocketOption String
socketOptionBijection = Bijection{..}
where
cso = "CustomSockOpt"
unCSO = \(CustomSockOpt nm) -> nm
defFwd = defShow cso unCSO _show
defBwd = defRead cso CustomSockOpt _parse
pairs = socketOptionPairs

instance Show SocketOption where
show = forward socketOptionBijection

instance Read SocketOption where
readPrec = tokenize $ backward socketOptionBijection


foreign import CALLCONV unsafe "getsockopt"
c_getsockopt :: CInt -> CInt -> CInt -> Ptr a -> Ptr CInt -> IO CInt
foreign import CALLCONV unsafe "setsockopt"
Expand Down
37 changes: 36 additions & 1 deletion Network/Socket/Posix/Cmsg.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

Expand All @@ -19,6 +20,9 @@ import System.Posix.Types (Fd(..))

import Network.Socket.Imports
import Network.Socket.Types
import Network.Socket.ReadShow

import qualified Text.Read as P

-- | Control message (ancillary data) including a pair of level and type.
data Cmsg = Cmsg {
Expand All @@ -32,7 +36,11 @@ data Cmsg = Cmsg {
data CmsgId = CmsgId {
cmsgLevel :: !CInt
, cmsgType :: !CInt
} deriving (Eq, Show)
} deriving (Eq)

-- | Unsupported identifier
pattern UnsupportedCmsgId :: CmsgId
pattern UnsupportedCmsgId = CmsgId (-1) (-1)

-- | The identifier for 'IPv4TTL'.
pattern CmsgIdIPv4TTL :: CmsgId
Expand Down Expand Up @@ -220,3 +228,30 @@ instance Storable IPv6PktInfo where

instance ControlMessage Fd where
controlMessageId = CmsgIdFd

cmsgIdPairs :: [Pair CmsgId String]
cmsgIdPairs =
[ (UnsupportedCmsgId, "UnsupportedCmsgId")
, (CmsgIdIPv4TTL, "CmsgIdIPv4TTL")
, (CmsgIdIPv6HopLimit, "CmsgIdIPv6HopLimit")
, (CmsgIdIPv4TOS, "CmsgIdIPv4TOS")
, (CmsgIdIPv6TClass, "CmsgIdIPv6TClass")
, (CmsgIdIPv4PktInfo, "CmsgIdIPv4PktInfo")
, (CmsgIdIPv6PktInfo, "CmsgIdIPv6PktInfo")
, (CmsgIdFd, "CmsgIdFd")
]

cmsgIdBijection :: Bijection CmsgId String
cmsgIdBijection = Bijection{..}
where
defname = "CmsgId"
unId = \(CmsgId l t) -> (l,t)
defFwd = defShow defname unId _show
defBwd = defRead defname (uncurry CmsgId) _parse
pairs = cmsgIdPairs

instance Show CmsgId where
show = forward cmsgIdBijection

instance Read CmsgId where
readPrec = tokenize $ backward cmsgIdBijection
95 changes: 95 additions & 0 deletions Network/Socket/ReadShow.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternGuards #-}


module Network.Socket.ReadShow where

import qualified Text.Read as P

-- type alias for individual correspondences of a (possibly partial) bijection
type Pair a b = (a, b)

-- | helper function for equality on first tuple element
{-# INLINE eqFst #-}
eqFst :: Eq a => a -> (a, b) -> Bool
eqFst x = \(x',_) -> x' == x

-- | helper function for equality on snd tuple element
{-# INLINE eqSnd #-}
eqSnd :: Eq b => b -> (a, b) -> Bool
eqSnd y = \(_,y') -> y' == y

-- | Return RHS element that is paired with provided LHS,
-- or apply a default fallback function if the list is partial
lookForward :: Eq a => (a -> b) -> [Pair a b] -> a -> b
lookForward defFwd ps x
= case filter (eqFst x) ps of
(_,y):_ -> y
[] -> defFwd x

-- | Return LHS element that is paired with provided RHS,
-- or apply a default fallback function if the list is partial
lookBackward :: Eq b => (b -> a) -> [Pair a b] -> b -> a
lookBackward defBwd ps y
= case filter (eqSnd y) ps of
(x,_):_ -> x
[] -> defBwd y

data Bijection a b
= Bijection
{ defFwd :: a -> b
, defBwd :: b -> a
, pairs :: [Pair a b]
}

-- | apply a bijection over an LHS-value
forward :: (Eq a) => Bijection a b -> a -> b
forward Bijection{..} = lookForward defFwd pairs

-- | apply a bijection over an RHS-value
backward :: (Eq b) => Bijection a b -> b -> a
backward Bijection{..} = lookBackward defBwd pairs

-- | show function for Int-like types that encodes negative numbers
-- with leading '_' instead of '-'
_showInt :: (Show a, Num a, Ord a) => a -> String
_showInt n | n < 0 = let ('-':s) = show n in '_':s
| otherwise = show n

-- | parse function for Int-like types that interprets leading '_'
-- as if it were '-' instead
_readInt :: (Read a) => String -> a
_readInt ('_':s) = read $ '-':s
_readInt s = read s


-- | parse a quote-separated pair into a tuple of Int-like values
-- should not be used if either type might have
-- literal quote-characters in the Read pre-image
_parse :: (Read a, Read b) => String -> (a, b)
_parse xy =
let (xs, '\'':ys) = break (=='\'') xy
in (_readInt xs, _readInt ys)
{-# INLINE _parse #-}

-- | inverse function to _parse
-- show a tuple of Int-like values as quote-separated strings
_show :: (Show a, Num a, Ord a, Show b, Num b, Ord b) => (a, b) -> String
_show (x, y) = _showInt x ++ "'" ++ _showInt y
{-# INLINE _show #-}

defShow :: Eq a => String -> (a -> b) -> (b -> String) -> (a -> String)
defShow name unwrap sho = \x -> name ++ (sho . unwrap $ x)
{-# INLINE defShow #-}

defRead :: Read a => String -> (b -> a) -> (String -> b) -> (String -> a)
defRead name wrap red = \s ->
case splitAt (length name) s of
(x, sn) | x == name -> wrap $ red sn
_ -> error $ "defRead: unable to parse " ++ show s
{-# INLINE defRead #-}

-- | Apply a precedence-invariant one-token parse function within ReadPrec monad
tokenize :: (String -> a) -> P.ReadPrec a
tokenize f = P.lexP >>= \(P.Ident x) -> return $ f x
{-# INLINE tokenize #-}
Loading