{-# LANGUAGE OverloadedStrings #-}
module Snap.Util.CORS
(
applyCORS
, CORSOptions(..)
, defaultOptions
, OriginList(..)
, OriginSet, mkOriginSet, origins
, HashableURI(..), HashableMethod (..)
) where
import Control.Applicative
import Control.Monad (join, when)
import Data.CaseInsensitive (CI)
import Data.Hashable (Hashable(..))
import Data.Maybe (fromMaybe)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Network.URI (URI (..), URIAuth (..), parseURI)
import qualified Data.Attoparsec.ByteString.Char8 as Attoparsec
import qualified Data.ByteString.Char8 as S
import qualified Data.CaseInsensitive as CI
import qualified Data.HashSet as HashSet
import qualified Data.Text as Text
import qualified Snap.Core as Snap
import Snap.Internal.Parsing (pTokens)
-- | A set of allowed origins. Build one with 'mkOriginSet'; read its
-- contents back with 'origins'.
newtype OriginSet = OriginSet
  { origins :: HashSet.HashSet HashableURI
    -- ^ The origins held by the set, wrapped as 'HashableURI's.
  }
-- | Which origins may make cross-origin requests. 'applyCORS' obtains one
-- through 'corsAllowOrigin' and tests the request origin with 'inOriginList'.
data OriginList
  = Everywhere
    -- ^ Every origin is accepted.
  | Nowhere
    -- ^ No origin is accepted; the wrapped handler runs without CORS headers.
  | Origins OriginSet
    -- ^ Only origins contained in the given set are accepted.
-- | CORS configuration consumed by 'applyCORS'. Every field is an action in
-- @m@ (or, for 'corsAllowedHeaders', a function into @m@). The actions run
-- while a request is being handled, so a setting can depend on that request.
data CORSOptions m = CORSOptions
  { corsAllowOrigin      :: m OriginList
    -- ^ Origins permitted to make cross-origin requests.
  , corsAllowCredentials :: m Bool
    -- ^ When 'True', @Access-Control-Allow-Credentials: true@ is added.
  , corsExposeHeaders    :: m (HashSet.HashSet (CI S.ByteString))
    -- ^ Header names sent in @Access-Control-Expose-Headers@ on
    -- non-preflight responses; nothing is sent when the set is empty.
  , corsAllowedMethods   :: m (HashSet.HashSet HashableMethod)
    -- ^ Methods a preflight may ask for; the whole set is advertised in
    -- @Access-Control-Allow-Methods@.
  , corsAllowedHeaders   :: HashSet.HashSet S.ByteString
                         -> m (HashSet.HashSet S.ByteString)
    -- ^ Receives the names from @Access-Control-Request-Headers@ and returns
    -- the allowed ones. A preflight succeeds only if every requested name is
    -- in the returned set.
  }
-- | Permissive options. Every setting is fixed and ignores the request:
--
-- * any origin is allowed ('Everywhere');
-- * credentials are allowed;
-- * no headers are exposed;
-- * @GET@, @POST@, @PUT@, @DELETE@ and @HEAD@ are allowed;
-- * every requested header is allowed, because 'corsAllowedHeaders' hands
--   the requested set straight back.
defaultOptions :: Monad m => CORSOptions m
defaultOptions = CORSOptions
  { corsAllowOrigin      = inject Everywhere
  , corsAllowCredentials = inject True
  , corsExposeHeaders    = inject HashSet.empty
  , corsAllowedMethods   = inject $! defaultAllowedMethods
  , corsAllowedHeaders   = inject
  }
  where
    -- Lift a pure value into the handler monad.
    inject :: Monad n => a -> n a
    inject = return
-- | The methods 'defaultOptions' allows: GET, POST, PUT, DELETE and HEAD.
defaultAllowedMethods :: HashSet.HashSet HashableMethod
defaultAllowedMethods =
  HashSet.fromList
    [ HashableMethod verb
    | verb <- [Snap.GET, Snap.POST, Snap.PUT, Snap.DELETE, Snap.HEAD]
    ]
-- | Wrap a handler with CORS processing.
--
-- What happens, case by case:
--
-- * No @Origin@ header, or one that cannot be decoded as a URI: the wrapped
--   handler runs unchanged.
-- * Origin not accepted by 'corsAllowOrigin': the wrapped handler runs
--   unchanged.
-- * Accepted origin on an @OPTIONS@ request: treated as a preflight.
--     * If @Access-Control-Request-Method@ names a method in
--       'corsAllowedMethods', and every name in
--       @Access-Control-Request-Headers@ is returned by 'corsAllowedHeaders',
--       the @Access-Control-Allow-*@ headers are added. The wrapped handler
--       is /not/ run.
--     * Otherwise the wrapped handler runs without CORS headers.
-- * Any other accepted request: the following headers are added, then the
--   wrapped handler runs:
--     * @Access-Control-Allow-Origin@, echoing the request origin;
--     * optionally @Access-Control-Allow-Credentials@;
--     * optionally @Access-Control-Expose-Headers@.
--
-- Preflights arrive as @OPTIONS@ requests. Apply this outside any
-- 'Snap.method' guard that would reject them.
applyCORS :: Snap.MonadSnap m => CORSOptions m -> m () -> m ()
applyCORS options m =
  (join . fmap decodeOrigin <$> getHeader "Origin") >>= maybe m corsRequestFrom
  where
    -- Route an accepted origin to the preflight or actual-request path.
    corsRequestFrom origin = do
      originList <- corsAllowOrigin options
      if origin `inOriginList` originList
        then Snap.method Snap.OPTIONS (preflightRequestFrom origin)
               <|> handleRequestFrom origin
        else m

    -- Answer a preflight. Any check that fails falls back to the plain handler.
    preflightRequestFrom origin = do
      maybeMethod <- fmap (parseMethod . S.unpack)
                       <$> getHeader "Access-Control-Request-Method"
      case maybeMethod of
        Nothing -> m
        Just requested -> do
          allowedMethods <- corsAllowedMethods options
          if requested `HashSet.member` allowedMethods
            then do
              -- A missing Access-Control-Request-Headers means "no extra
              -- headers". A malformed one rejects the preflight.
              maybeHeaders <- maybe (Just HashSet.empty) splitHeaders
                                <$> getHeader "Access-Control-Request-Headers"
              case maybeHeaders of
                Nothing -> m
                Just headers -> do
                  allowedHeaders <- corsAllowedHeaders options headers
                  if not $ HashSet.null $
                       headers `HashSet.difference` allowedHeaders
                    then m
                    else do
                      addAccessControlAllowOrigin origin
                      addAccessControlAllowCredentials
                      commaSepHeader
                        "Access-Control-Allow-Headers"
                        id (HashSet.toList allowedHeaders)
                      commaSepHeader
                        "Access-Control-Allow-Methods"
                        renderMethod (HashSet.toList allowedMethods)
            else m

    -- Decorate an actual (non-preflight) cross-origin request, then run it.
    handleRequestFrom origin = do
      addAccessControlAllowOrigin origin
      addAccessControlAllowCredentials
      exposeHeaders <- corsExposeHeaders options
      when (not $ HashSet.null exposeHeaders) $
        commaSepHeader
          "Access-Control-Expose-Headers"
          CI.original (HashSet.toList exposeHeaders)
      m

    addAccessControlAllowOrigin origin =
      addHeader "Access-Control-Allow-Origin"
                (encodeUtf8 $ Text.pack $ show origin)

    addAccessControlAllowCredentials = do
      allowCredentials <- corsAllowCredentials options
      when allowCredentials $
        addHeader "Access-Control-Allow-Credentials" "true"

    -- Parse an Origin header into a simplified URI.
    decodeOrigin :: S.ByteString -> Maybe URI
    decodeOrigin raw
      -- The Origin header carries an ASCII serialisation (RFC 6454).
      -- 'decodeUtf8' throws on malformed UTF-8, so reject non-ASCII bytes up
      -- front rather than let untrusted input raise an exception.
      | S.any (> '\DEL') raw = Nothing
      | otherwise = simplifyURI <$> parseURI (Text.unpack (decodeUtf8 raw))

    -- Header text for a method. Snap's derived 'Show' yields @Method "NAME"@
    -- for custom methods, which is Haskell syntax rather than a method token.
    -- So emit the raw name for those, and use 'show' only for the built-in
    -- constructors.
    renderMethod (HashableMethod (Snap.Method name)) = name
    renderMethod (HashableMethod known)              = S.pack (show known)

    addHeader k v = Snap.modifyResponse (Snap.addHeader k v)

    -- Add one comma-separated header, or nothing when the list is empty.
    commaSepHeader k f vs =
      case vs of
        [] -> return ()
        _  -> addHeader k $ S.intercalate ", " (map f vs)

    getHeader = Snap.getsRequest . Snap.getHeader

    -- Parse a token list; Nothing if the header does not parse.
    splitHeaders = either (const Nothing) (Just . HashSet.fromList) .
      Attoparsec.parseOnly pTokens
-- | Build an 'OriginSet' from URIs. Each URI is first passed through
-- 'simplifyURI', so only the parts it keeps take part in membership tests.
mkOriginSet :: [URI] -> OriginSet
mkOriginSet uris =
  OriginSet (HashSet.fromList [HashableURI (simplifyURI uri) | uri <- uris])
-- | Reduce a URI to the parts that identify an origin:
--
-- * the scheme is kept;
-- * the authority is kept, minus any user-info;
-- * path, query and fragment are cleared.
--
-- The scheme and host are also ASCII-lowercased. RFC 6454 treats these
-- components case-insensitively. Without this step, an entry such as
-- @HTTP://Example.com@ given to 'mkOriginSet' would never match the
-- lowercase serialisation clients send. The port is kept as written.
simplifyURI :: URI -> URI
simplifyURI uri = uri
  { uriScheme    = asciiLower (uriScheme uri)
  , uriAuthority = fmap simplifyURIAuth (uriAuthority uri)
  , uriPath      = ""
  , uriQuery     = ""
  , uriFragment  = ""
  }
  where
    simplifyURIAuth :: URIAuth -> URIAuth
    simplifyURIAuth auth = auth
      { uriUserInfo = ""
      , uriRegName  = asciiLower (uriRegName auth)
      }

    -- ASCII-only case folding, so non-ASCII characters are left untouched.
    asciiLower :: String -> String
    asciiLower = map lowerChar

    lowerChar :: Char -> Char
    lowerChar c
      | c >= 'A' && c <= 'Z' = toEnum (fromEnum c + 32)
      | otherwise            = c
-- | Turn a request-method name, such as the value of
-- @Access-Control-Request-Method@, into a 'HashableMethod'.
-- The nine standard names are matched exactly (case-sensitively) to their
-- constructors. Any other string becomes a custom @'Snap.Method' name@.
parseMethod :: String -> HashableMethod
parseMethod name =
  HashableMethod $ fromMaybe (Snap.Method (S.pack name)) (lookup name standard)
  where
    standard :: [(String, Snap.Method)]
    standard =
      [ ("GET",     Snap.GET)
      , ("POST",    Snap.POST)
      , ("HEAD",    Snap.HEAD)
      , ("PUT",     Snap.PUT)
      , ("DELETE",  Snap.DELETE)
      , ("TRACE",   Snap.TRACE)
      , ("OPTIONS", Snap.OPTIONS)
      , ("CONNECT", Snap.CONNECT)
      , ("PATCH",   Snap.PATCH)
      ]
-- | A 'URI' wrapper with a 'Hashable' instance, so URIs can be kept in a
-- 'HashSet.HashSet'. Equality is the wrapped URI's own equality.
newtype HashableURI = HashableURI URI
  deriving Eq
-- | Renders exactly as the wrapped 'URI' does.
instance Show HashableURI where
  show wrapped = case wrapped of
    HashableURI inner -> show inner
-- | Hashes every component of the wrapped 'URI': scheme, authority, path,
-- query and fragment. Covering all components keeps the hash in step with
-- an equality that compares whole URIs.
instance Hashable HashableURI where
  hashWithSalt salt (HashableURI uri) =
    salt `hashWithSalt` uriScheme uri
         `hashWithSalt` fmap authorityHash (uriAuthority uri)
         `hashWithSalt` uriPath uri
         `hashWithSalt` uriQuery uri
         `hashWithSalt` uriFragment uri
    where
      -- Fold the authority down to an Int, starting from the same outer salt.
      authorityHash auth =
        salt `hashWithSalt` uriUserInfo auth
             `hashWithSalt` uriRegName auth
             `hashWithSalt` uriPort auth
-- | Whether an origin is accepted by an 'OriginList'.
--
-- For 'Origins', membership is an exact 'HashableURI' lookup. The URI must
-- therefore already be simplified the same way the set's entries were
-- ('mkOriginSet' applies 'simplifyURI').
inOriginList :: URI -> OriginList -> Bool
inOriginList origin allowed =
  case allowed of
    Everywhere                -> True
    Nowhere                   -> False
    Origins (OriginSet known) -> HashSet.member (HashableURI origin) known
-- | A 'Snap.Method' wrapper with a 'Hashable' instance, so methods can be
-- stored in a 'HashSet.HashSet' (as 'corsAllowedMethods' does). Equality is
-- the wrapped method's own equality.
newtype HashableMethod = HashableMethod Snap.Method
  deriving Eq
-- | Hash a method consistently with 'Snap.Method' equality.
--
-- * Each well-known constructor gets its own tag (0-8).
-- * A custom @'Snap.Method' name@ whose @name@ spells a well-known method is
--   hashed exactly like that constructor.
-- * Any other custom method is hashed with tag 9 followed by its raw name.
--
-- Why the second rule: snap-core's 'Eq' for 'Snap.Method' normalises such
-- names, so @Snap.Method "GET" == Snap.GET@. Previously the two hashed
-- differently, so 'HashSet.member' could miss an equal element. Hashing
-- them alike is also harmless under a plain structural 'Eq'.
-- NOTE(review): normalising Eq confirmed for snap-core 1.x; verify if pinned
-- to an older snap-core.
instance Hashable HashableMethod where
  hashWithSalt salt (HashableMethod method) =
    case method of
      Snap.GET     -> tagged 0
      Snap.HEAD    -> tagged 1
      Snap.POST    -> tagged 2
      Snap.PUT     -> tagged 3
      Snap.DELETE  -> tagged 4
      Snap.TRACE   -> tagged 5
      Snap.OPTIONS -> tagged 6
      Snap.CONNECT -> tagged 7
      Snap.PATCH   -> tagged 8
      Snap.Method name ->
        case parseMethod (S.unpack name) of
          -- Not a well-known name: keep the original custom-method hash.
          HashableMethod (Snap.Method _) ->
            salt `hashWithSalt` (9 :: Int) `hashWithSalt` name
          -- Spelled like a well-known method: hash as that constructor.
          canonical -> hashWithSalt salt canonical
    where
      tagged :: Int -> Int
      tagged n = salt `hashWithSalt` n
-- | Renders exactly as the wrapped 'Snap.Method' does.
instance Show HashableMethod where
  show wrapped = case wrapped of
    HashableMethod verb -> show verb