|
From: <kr_...@us...> - 2003-09-06 19:59:10
|
Update of /cvsroot/htoolkit/HSQL/PostgreSQL
In directory sc8-pr-cvs1:/tmp/cvs-serv11008/PostgreSQL
Added Files:
HSQL.hsc
Log Message:
Support for PostgreSQL. The new implementation has better support for Sql<->Haskell data translation
--- NEW FILE: HSQL.hsc ---
-----------------------------------------------------------------------------------------
{-| Module      :  Database.PostgreSQL.HSQL
    Copyright   :  (c) Krasimir Angelov 2003
    License     :  BSD-style

    Maintainer  :  ka2...@ya...
    Stability   :  provisional
    Portability :  portable

    This module provides an interface to PostgreSQL databases, implemented
    on top of the libpq C client library.
-}
-----------------------------------------------------------------------------------------
module Database.PostgreSQL.HSQL
     ( SqlBind(..), SqlError(..), SqlType(..), Connection, Statement
     , catchSql          -- :: IO a -> (SqlError -> IO a) -> IO a
     , handleSql         -- :: (SqlError -> IO a) -> IO a -> IO a
     , sqlExceptions     -- :: Exception -> Maybe SqlError
     , connect           -- :: String -> String -> String -> String -> IO Connection
     , disconnect        -- :: Connection -> IO ()
     , execute           -- :: Connection -> String -> IO ()
     , query             -- :: Connection -> String -> IO Statement
     , closeStatement    -- :: Statement -> IO ()
     , fetch             -- :: Statement -> IO Bool
     , inTransaction     -- :: Connection -> (Connection -> IO a) -> IO a
     , getFieldValueMB   -- :: SqlBind a => Statement -> String -> IO (Maybe a)
     , getFieldValue     -- :: SqlBind a => Statement -> String -> IO a
     , getFieldValue'    -- :: SqlBind a => Statement -> String -> a -> IO a
     , getFieldValueType -- :: Statement -> String -> SqlType
     , getFieldsTypes    -- :: Statement -> [(String, SqlType)]
     , forEachRow        -- :: (Statement -> s -> IO s) -> Statement -> s -> IO s
     , forEachRow'       -- :: (Statement -> IO ()) -> Statement -> IO ()
     , collectRows       -- :: (Statement -> IO s) -> Statement -> IO [s]
     , Point(..), Line(..), Path(..), Box(..), Circle(..), Polygon(..)
     ) where
import Data.Dynamic
import Data.IORef
import Foreign
import Foreign.C
import Control.Exception (throwDyn, catchDyn, dynExceptions, Exception(..))
import Control.Monad(when,unless,mplus)
import System.Time
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import Text.Read
# include <time.h>
#include <libpq-fe.h>
#include <postgres.h>
#include <catalog/pg_type.h>
-- Opaque libpq handles; this module only ever passes them back to libpq.
type PGconn = Ptr ()
type PGresult = Ptr ()
-- Integral representations of the libpq enum/typedef types, as computed by hsc2hs.
type ConnStatusType = #type ConnStatusType
type ExecStatusType = #type ExecStatusType
type Oid = #type Oid

foreign import ccall "libpq-fe.h PQsetdbLogin" pqSetdbLogin :: CString -> CString -> CString -> CString -> CString -> CString -> CString -> IO PGconn
foreign import ccall "libpq-fe.h PQstatus" pqStatus :: PGconn -> IO ConnStatusType
foreign import ccall "libpq-fe.h PQerrorMessage" pqErrorMessage :: PGconn -> IO CString
foreign import ccall "libpq-fe.h PQfinish" pqFinish :: PGconn -> IO ()
foreign import ccall "libpq-fe.h PQexec" pqExec :: PGconn -> CString -> IO PGresult
foreign import ccall "libpq-fe.h PQresultStatus" pqResultStatus :: PGresult -> IO ExecStatusType
foreign import ccall "libpq-fe.h PQresStatus" pqResStatus :: ExecStatusType -> IO CString
foreign import ccall "libpq-fe.h PQresultErrorMessage" pqResultErrorMessage :: PGresult -> IO CString

-- The libpq functions below take and return C `int`.  Haskell's Int is not
-- guaranteed to have the size of `int` (it is wider on LP64 platforms), so
-- they are imported at CInt and wrapped to keep the Int-based signatures
-- the rest of this module uses.
foreign import ccall "libpq-fe.h PQnfields" c_PQnfields :: PGresult -> IO CInt
foreign import ccall "libpq-fe.h PQntuples" c_PQntuples :: PGresult -> IO CInt
foreign import ccall "libpq-fe.h PQfname" c_PQfname :: PGresult -> CInt -> IO CString
foreign import ccall "libpq-fe.h PQftype" c_PQftype :: PGresult -> CInt -> IO Oid
foreign import ccall "libpq-fe.h PQfmod" c_PQfmod :: PGresult -> CInt -> IO CInt
foreign import ccall "libpq-fe.h PQfnumber" c_PQfnumber :: PGresult -> CString -> IO CInt
foreign import ccall "libpq-fe.h PQgetvalue" c_PQgetvalue :: PGresult -> CInt -> CInt -> IO CString
foreign import ccall "libpq-fe.h PQgetisnull" c_PQgetisnull :: PGresult -> CInt -> CInt -> IO CInt

-- | Number of columns in a result (PQnfields).
pgNFields :: PGresult -> IO Int
pgNFields res = fmap fromIntegral (c_PQnfields res)

-- | Number of rows in a result (PQntuples).
pqNTuples :: PGresult -> IO Int
pqNTuples res = fmap fromIntegral (c_PQntuples res)

-- | Name of column @col@ (PQfname).
pgFName :: PGresult -> Int -> IO CString
pgFName res col = c_PQfname res (fromIntegral col)

-- | Type Oid of column @col@ (PQftype).
pqFType :: PGresult -> Int -> IO Oid
pqFType res col = c_PQftype res (fromIntegral col)

-- | Type modifier of column @col@ (PQfmod).
pqFMod :: PGresult -> Int -> IO Int
pqFMod res col = fmap fromIntegral (c_PQfmod res (fromIntegral col))

-- | Column number for a column name (PQfnumber).
pqFNumber :: PGresult -> CString -> IO Int
pqFNumber res name = fmap fromIntegral (c_PQfnumber res name)

-- | Text of the field at (@row@, @col@) (PQgetvalue).
pqGetvalue :: PGresult -> Int -> Int -> IO CString
pqGetvalue res row col = c_PQgetvalue res (fromIntegral row) (fromIntegral col)

-- | NULL indicator of the field at (@row@, @col@) (PQgetisnull).
pqGetisnull :: PGresult -> Int -> Int -> IO Int
pqGetisnull res row col = fmap fromIntegral (c_PQgetisnull res (fromIntegral row) (fromIntegral col))
-- | An open database connection, wrapping the libpq connection pointer.
newtype Connection = Connection PGconn

-- | The result of 'query' together with a row cursor over it.
data Statement
   = Statement
   { pRes :: !PGresult          -- libpq result handle returned by PQexec
   , tupleIndex :: IORef Int    -- current row; 'query' starts it at -1 and 'fetch' advances it
   , countTuples:: !Int         -- number of rows in the result (PQntuples)
   , connection :: !Connection  -- connection the statement was issued on
   , fields :: ![FieldDef]      -- column names and types, in column order
   }

-- | A column description: the column name and its classified 'SqlType'.
type FieldDef = (String, SqlType)
-- | Column types, as classified by 'mkSqlType' from a column's type Oid
-- and type modifier.
data SqlType
   = SqlChar Int          -- BPCHAROID; argument is the type modifier minus 4
   | SqlVarChar Int       -- VARCHAROID (modifier minus 4); NAMEOID is reported as SqlVarChar 31
   | SqlText              -- TEXTOID
   | SqlNumeric Int Int   -- NUMERICOID; (modifier-4) split into high and low 16 bits (presumably precision, scale)
   | SqlSmallInt          -- INT2OID
   | SqlInteger           -- INT4OID
   | SqlReal              -- FLOAT4OID
   | SqlDouble            -- FLOAT8OID
   | SqlBool              -- BOOLOID
   | SqlBit Int           -- BITOID; raw type modifier
   | SqlVarBit Int        -- VARBITOID; raw type modifier
   | SqlTinyInt           -- NOTE(review): produced for BYTEAOID columns by mkSqlType
   | SqlBigInt            -- INT8OID
   | SqlDate              -- DATEOID
   | SqlTime              -- TIMEOID
   | SqlAbsTime           -- ABSTIMEOID
   | SqlRelTime           -- RELTIMEOID
   | SqlTimeTZ            -- TIMETZOID
   | SqlTimeInterval      -- INTERVALOID
   | SqlAbsTimeInterval   -- TINTERVALOID
   | SqlTimeStamp         -- TIMESTAMPOID
   | SqlMoney             -- CASHOID
   | SqlINetAddr          -- INETOID
   | SqlCIDRAddr          -- CIDROID
   | SqlMacAddr           -- Oid 829, matched by literal value
   | SqlPoint             -- POINTOID
   | SqlLSeg              -- LSEGOID
   | SqlPath              -- PATHOID
   | SqlBox               -- BOXOID
   | SqlPolygon           -- POLYGONOID
   | SqlLine              -- LINEOID
   | SqlCircle            -- CIRCLEOID
   | SqlUnknown           -- UNKNOWNOID
   deriving (Eq, Show)
-- | Errors raised by this module (thrown with 'throwDyn'; catch them with
-- 'catchSql' or 'handleSql').
data SqlError
   = SqlError                  -- failure reported by libpq or the server
       { seState :: String     -- "C" from 'connect', "E" from 'execute' / 'query'
       , seNativeError :: Int  -- libpq connection status or result status code
       , seErrorMsg :: String  -- message text obtained from libpq
       }
   | SqlNoData                 -- thrown by 'getFieldValueMB' when the cursor is past the last row
   | SqlBadTypeCast            -- the column's SqlType has no conversion to the requested
       { seFieldName :: String --   Haskell type ('fromSqlValue' returned Nothing)
       , seFieldType :: SqlType
       }
   | SqlFetchNull              -- 'getFieldValue' was applied to a NULL field
       { seFieldName :: String
       }
   deriving (Show, Typeable)
-----------------------------------------------------------------------------------------
-- routines for handling exceptions
-----------------------------------------------------------------------------------------
-- | Run @action@; any 'SqlError' it throws is passed to @handler@.
catchSql :: IO a -> (SqlError -> IO a) -> IO a
catchSql action handler = catchDyn action handler

-- | 'catchSql' with the arguments in the opposite order (handler first).
handleSql :: (SqlError -> IO a) -> IO a -> IO a
handleSql = flip catchSql

-- | Project an 'SqlError' out of a general exception, if it carries one.
sqlExceptions :: Exception -> Maybe SqlError
sqlExceptions e = fromDynamic =<< dynExceptions e
-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------
-- | Open a connection through libpq's PQsetdbLogin.
--
-- Arguments: server host, database name, user name, password.  The port,
-- options and tty arguments of PQsetdbLogin are passed as NULL.
--
-- Throws 'SqlError' with 'seState' "C" and the connection status as
-- 'seNativeError' when the connection status is not CONNECTION_OK; the
-- failed connection object is passed to PQfinish before throwing.
connect :: String -> String -> String -> String -> IO Connection
connect server database user authentication = do
  -- Every argument buffer is released when its withCString scope ends,
  -- i.e. right after PQsetdbLogin returns.
  pConn <- withCString server         $ \pServer ->
           withCString database       $ \pDatabase ->
           withCString user           $ \pUser ->
           withCString authentication $ \pAuthentication ->
             pqSetdbLogin pServer nullPtr nullPtr nullPtr pDatabase pUser pAuthentication
  status <- pqStatus pConn
  unless (status == (#const CONNECTION_OK)) $ do
    errMsg <- pqErrorMessage pConn >>= peekCString
    pqFinish pConn
    throwDyn (SqlError {seState="C", seNativeError=fromIntegral status, seErrorMsg=errMsg})
  return (Connection pConn)
-- | Close a connection by handing its libpq pointer to PQfinish.
disconnect :: Connection -> IO ()
disconnect conn = case conn of
  Connection pConn -> pqFinish pConn
-----------------------------------------------------------------------------------------
-- queries
-----------------------------------------------------------------------------------------
-- | Release a result object obtained from 'pqExec'.
foreign import ccall "libpq-fe.h PQclear" pqClear :: PGresult -> IO ()

-- | Private helper shared by 'execute' and 'query': send one SQL command
-- and check its outcome.
--
-- On PGRES_COMMAND_OK or PGRES_TUPLES_OK it returns the result handle and
-- its status; the caller then owns the handle.  Otherwise it throws an
-- 'SqlError' with 'seState' "E", releasing the result handle first.
execSql :: PGconn -> String -> IO (PGresult, ExecStatusType)
execSql pConn sqlExpr = do
  res <- withCString sqlExpr (pqExec pConn)
  -- A NULL result carries no message; the reason is on the connection.
  when (res == nullPtr) $ do
    errMsg <- pqErrorMessage pConn >>= peekCString
    throwDyn (SqlError {seState="E", seNativeError=(#const PGRES_FATAL_ERROR), seErrorMsg=errMsg})
  status <- pqResultStatus res
  unless (status == (#const PGRES_COMMAND_OK) || status == (#const PGRES_TUPLES_OK)) $ do
    -- Copy the message out first: libpq frees it together with the result.
    errMsg <- pqResultErrorMessage res >>= peekCString
    pqClear res
    throwDyn (SqlError {seState="E", seNativeError=fromIntegral status, seErrorMsg=errMsg})
  return (res, status)

-- | Execute an SQL command whose result rows are not needed.
--
-- Throws 'SqlError' if the command fails.  The libpq result object is
-- released before returning.
execute :: Connection -> String -> IO ()
execute (Connection pConn) sqlExpr = do
  (res, _) <- execSql pConn sqlExpr
  pqClear res

-- | Execute an SQL command and return a 'Statement' positioned before the
-- first row (advance it with 'fetch').
--
-- Column descriptions are read only for a PGRES_TUPLES_OK result; for a
-- PGRES_COMMAND_OK result the statement has no fields.  Throws 'SqlError'
-- if the command fails.
query :: Connection -> String -> IO Statement
query conn@(Connection pConn) sqlExpr = do
  (res, status) <- execSql pConn sqlExpr
  defs <- if status == (#const PGRES_TUPLES_OK)
            then pgNFields res >>= getFieldDefs res 0
            else return []
  nTuples <- pqNTuples res
  cursor <- newIORef (-1)
  return (Statement {pRes=res, connection=conn, fields=defs, countTuples=nTuples, tupleIndex=cursor})
  where
    -- Describe columns i .. n-1 of result r.
    getFieldDefs r i n
      | i >= n = return []
      | otherwise = do
          name <- pgFName r i >>= peekCString
          dataType <- pqFType r i
          modifier <- pqFMod r i
          rest <- getFieldDefs r (i+1) n
          return ((name, mkSqlType dataType modifier) : rest)

-- | Classify a column from its type Oid and type modifier.
--
-- Oids without a dedicated constructor map to 'SqlUnknown' rather than
-- failing with a pattern-match error.
mkSqlType :: Oid -> Int -> SqlType
mkSqlType (#const BPCHAROID)    size = SqlChar (size-4)
mkSqlType (#const VARCHAROID)   size = SqlVarChar (size-4)
mkSqlType (#const NAMEOID)      _    = SqlVarChar 31
mkSqlType (#const TEXTOID)      _    = SqlText
mkSqlType (#const NUMERICOID)   size = SqlNumeric ((size-4) `div` 0x10000) ((size-4) `mod` 0x10000)
mkSqlType (#const INT2OID)      _    = SqlSmallInt
mkSqlType (#const INT4OID)      _    = SqlInteger
mkSqlType (#const FLOAT4OID)    _    = SqlReal
mkSqlType (#const FLOAT8OID)    _    = SqlDouble
mkSqlType (#const BOOLOID)      _    = SqlBool
mkSqlType (#const BITOID)       size = SqlBit size
mkSqlType (#const VARBITOID)    size = SqlVarBit size
mkSqlType (#const BYTEAOID)     _    = SqlTinyInt
mkSqlType (#const INT8OID)      _    = SqlBigInt
mkSqlType (#const DATEOID)      _    = SqlDate
mkSqlType (#const TIMEOID)      _    = SqlTime
mkSqlType (#const TIMETZOID)    _    = SqlTimeTZ
mkSqlType (#const ABSTIMEOID)   _    = SqlAbsTime
mkSqlType (#const RELTIMEOID)   _    = SqlRelTime
mkSqlType (#const INTERVALOID)  _    = SqlTimeInterval
mkSqlType (#const TINTERVALOID) _    = SqlAbsTimeInterval
mkSqlType (#const TIMESTAMPOID) _    = SqlTimeStamp
mkSqlType (#const CASHOID)      _    = SqlMoney
mkSqlType (#const INETOID)      _    = SqlINetAddr
mkSqlType (#const 829)          _    = SqlMacAddr   -- hack: macaddr Oid by value
mkSqlType (#const CIDROID)      _    = SqlCIDRAddr
mkSqlType (#const POINTOID)     _    = SqlPoint
mkSqlType (#const LSEGOID)      _    = SqlLSeg
mkSqlType (#const PATHOID)      _    = SqlPath
mkSqlType (#const BOXOID)       _    = SqlBox
mkSqlType (#const POLYGONOID)   _    = SqlPolygon
mkSqlType (#const LINEOID)      _    = SqlLine
mkSqlType (#const CIRCLEOID)    _    = SqlCircle
mkSqlType (#const UNKNOWNOID)   _    = SqlUnknown
mkSqlType _                     _    = SqlUnknown
-- | Move the statement's cursor to the next row.
--
-- Returns True when a row is now current.  Returns False, leaving the
-- cursor where it was, once every row has been visited.
fetch :: Statement -> IO Bool
fetch stmt = do
  current <- readIORef (tupleIndex stmt)
  let next = current + 1
  if next < countTuples stmt
    then do
      writeIORef (tupleIndex stmt) next
      return True
    else return False
-- | Finish with a statement.
--
-- NOTE(review): this is currently a no-op.  In particular, the PGresult
-- held in 'pRes' is never passed to PQclear, so its memory is not released.
closeStatement :: Statement -> IO ()
closeStatement = const (return ())
-----------------------------------------------------------------------------------------
-- transactions
-----------------------------------------------------------------------------------------
-- | Run @action@ between "begin" and "commit" on @conn@ and return its result.
--
-- If @action@ throws an 'SqlError', "rollback" is executed and the same
-- error is rethrown.  Only 'SqlError' is caught here; any other exception
-- propagates without a rollback being issued.
inTransaction :: Connection -> (Connection -> IO a) -> IO a
inTransaction conn action = do
  execute conn "begin"
  -- Sequence the rollback with (>>): the rethrow does not consume its result.
  r <- catchSql (action conn) (\err -> execute conn "rollback" >> throwDyn err)
  execute conn "commit"
  return r
-----------------------------------------------------------------------------------------
-- binding
-----------------------------------------------------------------------------------------
-- | Types whose values can be exchanged with the server as text.
class SqlBind a where
  -- | Render a value as the text of an SQL literal, including any quoting
  -- it needs, for splicing into a command string.
  toSqlValue :: a -> String
  -- | Decode the server's text for a column of the given 'SqlType';
  -- Nothing when that column type is not accepted for this Haskell type.
  fromSqlValue :: SqlType -> String -> Maybe a
-- | Int accepts int4 and int2 columns and is written as a bare numeral.
instance SqlBind Int where
  fromSqlValue t s
    | t == SqlInteger || t == SqlSmallInt = Just (read s)
    | otherwise = Nothing
  toSqlValue = show
-- | Integer accepts int4, int2 and int8 columns and is written as a bare numeral.
instance SqlBind Integer where
  fromSqlValue t s
    | t `elem` [SqlInteger, SqlSmallInt, SqlBigInt] = Just (read s)
    | otherwise = Nothing
  toSqlValue = show
-- | Any column can be read as its raw text.  Strings are written as quoted
-- SQL literals.
instance SqlBind String where
  fromSqlValue _ = Just
  toSqlValue s = '\'' : foldr mapChar "'" s
    where
      -- Backslash must be doubled; otherwise a trailing backslash would
      -- escape the closing quote and let the rest of the value run on as SQL.
      -- Quotes are doubled (''), which every PostgreSQL version accepts.
      -- NOTE(review): the backslash escapes assume the server treats '\' as
      -- an escape character in '...' literals (standard_conforming_strings off).
      mapChar '\\' rest = '\\':'\\':rest
      mapChar '\'' rest = '\'':'\'':rest
      mapChar '\n' rest = '\\':'n':rest
      mapChar '\r' rest = '\\':'r':rest
      mapChar '\t' rest = '\\':'t':rest
      mapChar c    rest = c:rest
-- | Bool accepts bool columns (text "t" is True, anything else False) and
-- is written as 't' / 'f'.
instance SqlBind Bool where
  fromSqlValue t s = case t of
    SqlBool -> Just ("t" == s)
    _       -> Nothing
  toSqlValue b = if b then "'t'" else "'f'"
-- | Double accepts numeric, float8 and float4 columns and is written with 'show'.
instance SqlBind Double where
  fromSqlValue t s = case t of
    SqlNumeric _ _ -> Just (read s)
    SqlDouble      -> Just (read s)
    SqlReal        -> Just (read s)
    _              -> Nothing
  toSqlValue = show
-- | Build a 'ClockTime' from wall-clock fields expressed in zone @tz@.
--
-- @tz@ is an offset in seconds, positive for "+HH" zones (the convention
-- 'parseTZ' and 'currTZ' use).  @mon@ is 1-based (1 = January).
--
-- The fields are handed to C mktime, which interprets them as local time
-- (offset 'currTZ').  The result is then corrected by @currTZ - tz@ so that
-- it denotes the instant the fields name in zone @tz@.
--
-- NOTE(review): currTZ is the offset at program start, while mktime may
-- apply a different DST offset for the given date.  A -1 return from
-- mktime is not checked.
mkClockTime :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> ClockTime
mkClockTime year mon mday hour minute sec tz =
  unsafePerformIO $
    allocaBytes (#const sizeof(struct tm)) $ \ p_tm -> do
      (#poke struct tm,tm_sec  ) p_tm (fromIntegral sec          :: CInt)
      (#poke struct tm,tm_min  ) p_tm (fromIntegral minute       :: CInt)
      (#poke struct tm,tm_hour ) p_tm (fromIntegral hour         :: CInt)
      (#poke struct tm,tm_mday ) p_tm (fromIntegral mday         :: CInt)
      (#poke struct tm,tm_mon  ) p_tm (fromIntegral (mon-1)      :: CInt)
      (#poke struct tm,tm_year ) p_tm (fromIntegral (year-1900)  :: CInt)
      (#poke struct tm,tm_isdst) p_tm (-1 :: CInt)  -- let mktime decide DST
      t <- mktime p_tm
      -- mktime used the local offset currTZ; the fields were in zone tz, so
      -- the true instant is (local result) + currTZ - tz.
      return (TOD (fromIntegral t + fromIntegral (currTZ - tz)) 0)

foreign import ccall unsafe mktime :: Ptr () -> IO CTime
-- | The local zone offset ('ctTZ' of the current calendar time), sampled
-- once; NOINLINE keeps the unsafePerformIO from being duplicated.  Hack.
{-# NOINLINE currTZ #-}
currTZ :: Int
currTZ = unsafePerformIO (fmap ctTZ (getClockTime >>= toCalendarTime))
-- | Parse a signed hour offset such as "+02" or "-05" into hours.
parseTZ :: ReadP Int
parseTZ = do
  applySign <- (char '+' >> return id) `mplus` (char '-' >> return negate)
  hours <- readS_to_P reads
  return (applySign hours)
-- | Run a parser and return its result if it produced exactly one parse.
--
-- Text left over after that parse is ignored.  No parse, or an ambiguous
-- one, yields Nothing instead of a pattern-match failure, so callers such
-- as 'getFieldValueMB' can report it as a bad type cast.
f_read :: ReadP a -> String -> Maybe a
f_read f s =
  case readP_to_S f s of
    [(x, _)] -> Just x
    _        -> Nothing
-- | True for the ASCII digits '0'..'9' (local, to avoid importing Data.Char).
isDecDigit :: Char -> Bool
isDecDigit c = c >= '0' && c <= '9'

-- | Parse a wall-clock time HH:MM:SS, optionally followed by a fractional
-- seconds part such as ".123456", which is consumed and discarded.
-- Produces at most one parse for any input.
parseClockFields :: ReadP (Int, Int, Int)
parseClockFields = do
  hour <- readS_to_P reads
  _ <- char ':'
  minutes <- readS_to_P reads
  _ <- char ':'
  -- Seconds are read digit by digit: 'reads' would lex "10.5" as a
  -- fractional literal and then refuse to produce an Int.
  secDigits <- munch1 isDecDigit
  _ <- munch (\c -> c == '.' || isDecDigit c)
  -- secDigits is a non-empty run of digits, so 'read' cannot fail.
  return (hour, minutes, read secDigits)

-- | Dates and times decoded from the server's text output; written back as
-- a 'YYYY-M-D H:M:S' literal.
instance SqlBind ClockTime where
  -- time with time zone, e.g. "19:59:10+02".
  -- NOTE(review): only whole-hour zones are read; a ":MM" part is ignored.
  fromSqlValue SqlTimeTZ s = f_read getTime s
    where
      getTime :: ReadP ClockTime
      getTime = do
        (hour, minutes, seconds) <- parseClockFields
        tz <- parseTZ
        return (mkClockTime 1970 1 1 hour minutes seconds (tz*3600))
  -- time without time zone, interpreted in the local zone
  fromSqlValue SqlTime s = f_read getTime s
    where
      getTime :: ReadP ClockTime
      getTime = do
        (hour, minutes, seconds) <- parseClockFields
        return (mkClockTime 1970 1 1 hour minutes seconds currTZ)
  -- date "YYYY-MM-DD", at local midnight
  fromSqlValue SqlDate s = f_read getDate s
    where
      getDate :: ReadP ClockTime
      getDate = do
        year <- readS_to_P reads
        _ <- satisfy (=='-')
        month <- readS_to_P reads
        _ <- satisfy (=='-')
        day <- readS_to_P reads
        return (mkClockTime year month day 0 0 0 currTZ)
  -- timestamp "YYYY-MM-DD HH:MM:SS[.fraction][+HH]"; without a zone suffix
  -- the value is interpreted in the local zone
  fromSqlValue SqlTimeStamp s = f_read getTimeStamp s
    where
      getTimeStamp :: ReadP ClockTime
      getTimeStamp = do
        year <- readS_to_P reads
        _ <- satisfy (=='-')
        month <- readS_to_P reads
        _ <- satisfy (=='-')
        day <- readS_to_P reads
        skipSpaces
        (hour, minutes, seconds) <- parseClockFields
        -- End of input and a sign character are mutually exclusive, so
        -- this choice yields at most one parse.
        tz <- (eof >> return currTZ) `mplus` fmap (*3600) parseTZ
        return (mkClockTime year month day hour minutes seconds tz)
  fromSqlValue _ _ = Nothing
  -- Rendered in UTC with a numeric month.
  -- NOTE(review): no zone suffix is emitted, so a column that interprets
  -- zone-less input in the session zone will not see this as UTC.
  toSqlValue ct = '\'' : (shows (ctYear t) .
                          dash .
                          shows (fromEnum (ctMonth t) + 1) .
                          dash .
                          shows (ctDay t) .
                          space .
                          shows (ctHour t) .
                          colon .
                          shows (ctMin t) .
                          colon .
                          shows (ctSec t)) "'"
    where
      t = toUTCTime ct
      dash = showChar '-'
      space = showChar ' '
      colon = showChar ':'
-- | Geometric values exchanged with the server's geometric column types.

-- | A point (x, y).
data Point = Point Double Double deriving (Eq, Show)
-- | A segment between two points; bound to lseg columns.
data Line = Line Point Point deriving (Eq, Show)
-- | An open or closed path through a list of points.
data Path = OpenPath [Point] | ClosedPath [Point] deriving (Eq, Show)
-- | A box given by two corners: Box x1 y1 x2 y2 is (x1,y1),(x2,y2).
data Box = Box Double Double Double Double deriving (Eq, Show)
-- | A circle given by its centre and radius.
data Circle = Circle Point Double deriving (Eq, Show)
-- | A polygon given by its vertices.
data Polygon = Polygon [Point] deriving (Eq, Show)
-- | Points are read from and written as "(x,y)".
instance SqlBind Point where
  fromSqlValue SqlPoint s = (\(x, y) -> Just (Point x y)) (read s)
  fromSqlValue _ _ = Nothing
  toSqlValue (Point x y) = "'" ++ show (x, y) ++ "'"
-- | Lines are bound to lseg columns, read from and written as "[(x1,y1),(x2,y2)]".
instance SqlBind Line where
  fromSqlValue t s = case t of
    SqlLSeg -> case read s of
      [(x1, y1), (x2, y2)] -> Just (Line (Point x1 y1) (Point x2 y2))
    _ -> Nothing
  toSqlValue (Line (Point x1 y1) (Point x2 y2)) = "'" ++ show [(x1, y1), (x2, y2)] ++ "'"
-- | Paths.  Closed paths use the form "((x1,y1),...)" and open paths
-- "[(x1,y1),...]".  lseg and point columns are also accepted on input.
instance SqlBind Path where
  -- closed path: strip the outer parentheses and read the points as a list
  fromSqlValue SqlPath ('(':s) = case read ("["++init s++"]") of
    ps -> Just (ClosedPath (map (\(x,y) -> Point x y) ps))
  -- open path: already in list syntax
  fromSqlValue SqlPath s = case read s of
    ps -> Just (OpenPath (map (\(x,y) -> Point x y) ps))
  -- a segment becomes a two-point open path
  fromSqlValue SqlLSeg s = case read s of
    [(x1,y1),(x2,y2)] -> Just (OpenPath [Point x1 y1, Point x2 y2])
  -- a point becomes a one-point closed path
  fromSqlValue SqlPoint s = case read s of
    (x,y) -> Just (ClosedPath [Point x y])
  fromSqlValue _ _ = Nothing
  -- Points are rendered as coordinate pairs, not with the derived Show of
  -- Point (which would print "Point 1.0 2.0").
  toSqlValue (OpenPath ps) = '\'' : shows (map coords ps) "'"
    where coords (Point x y) = (x, y)
  toSqlValue (ClosedPath ps) = "'(" ++ init (tail (show (map coords ps))) ++ ")'"
    where coords (Point x y) = (x, y)
-- | Boxes are read from "(x1,y1),(x2,y2)" and written as "((x1,y1),(x2,y2))".
instance SqlBind Box where
  fromSqlValue SqlBox s =
    (\((x1, y1), (x2, y2)) -> Just (Box x1 y1 x2 y2)) (read ('(' : s ++ ")"))
  fromSqlValue _ _ = Nothing
  toSqlValue (Box x1 y1 x2 y2) = "'" ++ show ((x1, y1), (x2, y2)) ++ "'"
-- | Polygons are read from and written as "((x1,y1),...,(xn,yn))".
instance SqlBind Polygon where
  fromSqlValue SqlPolygon s = case read ("["++init (tail s)++"]") of
    ps -> Just (Polygon (map (\(x,y) -> Point x y) ps))
  fromSqlValue _ _ = Nothing
  -- Vertices are rendered as coordinate pairs, and the closing parenthesis
  -- stays inside the quoted literal.
  toSqlValue (Polygon ps) = "'(" ++ init (tail (show (map coords ps))) ++ ")'"
    where coords (Point x y) = (x, y)
-- | Circles are read from and written as "<(x,y),r>".
instance SqlBind Circle where
  fromSqlValue SqlCircle s = case read ("("++init (tail s)++")") of
    ((x,y),r) -> Just (Circle (Point x y) r)
  fromSqlValue _ _ = Nothing
  -- The closing '>' belongs inside the quoted literal.
  toSqlValue (Circle (Point x y) r) = "'<" ++ show (x,y) ++ "," ++ show r ++ ">'"
-- | Read column @name@ of the current row.
--
-- Returns Nothing for SQL NULL.  Throws:
--
-- * 'SqlNoData' when no row is current (before the first 'fetch', or past
--   the last row).
-- * 'SqlBadTypeCast' when the column's 'SqlType' cannot be converted to the
--   requested type.
--
-- An unknown column name raises 'error' via 'findFieldInfo'.
getFieldValueMB :: SqlBind a => Statement -> String -> IO (Maybe a)
getFieldValueMB (Statement {pRes=res, fields=fieldDefs, countTuples=nTuples, tupleIndex=cursor}) name = do
  index <- readIORef cursor
  -- The cursor starts at -1; that is not a valid row number for libpq.
  when (index < 0 || index >= nTuples) (throwDyn SqlNoData)
  let (sqlType, colNumber) = findFieldInfo name fieldDefs 0
  isnull <- pqGetisnull res index colNumber
  if isnull == 1
    then return Nothing
    else do
      value <- pqGetvalue res index colNumber >>= peekCString
      case fromSqlValue sqlType value of
        Just v  -> return (Just v)
        Nothing -> throwDyn (SqlBadTypeCast name sqlType)
-- | Like 'getFieldValueMB', but a NULL field raises 'SqlFetchNull'.
getFieldValue :: SqlBind a => Statement -> String -> IO a
getFieldValue stmt name =
  getFieldValueMB stmt name >>= maybe (throwDyn (SqlFetchNull name)) return
-- | Like 'getFieldValueMB', but a NULL field yields the given default.
getFieldValue' :: SqlBind a => Statement -> String -> a -> IO a
getFieldValue' stmt name def = fmap (maybe def id) (getFieldValueMB stmt name)
-- | The 'SqlType' of column @name@; an unknown name raises 'error'.
getFieldValueType :: Statement -> String -> SqlType
getFieldValueType stmt name = fst (findFieldInfo name (fields stmt) 0)

-- | Names and types of all columns, in column order.
getFieldsTypes :: Statement -> [(String, SqlType)]
getFieldsTypes stmt = fields stmt

-- | Look up @name@ in the column list.  Returns its type and its position,
-- counted from @colNumber@ for the first entry.  Raises 'error' if the
-- column does not exist.
findFieldInfo :: String -> [FieldDef] -> Int -> (SqlType,Int)
findFieldInfo name defs colNumber = go defs colNumber
  where
    go [] _ = error ("Undefined column name \"" ++ name ++ "\"")
    go ((name', sqlType) : rest) col
      | name' == name = (sqlType, col)
      | otherwise     = go rest (col + 1)
-----------------------------------------------------------------------------------------
-- helpers
-----------------------------------------------------------------------------------------
-- | Fold @f@ over the remaining rows, threading the state @s@.  The
-- statement is closed once 'fetch' reports no more rows (not if @f@ throws).
forEachRow :: (Statement -> s -> IO s) -> Statement -> s -> IO s
forEachRow f stmt = loop
  where
    loop acc = do
      more <- fetch stmt
      if more
        then f stmt acc >>= loop
        else do
          closeStatement stmt
          return acc
-- | Run @f@ on each remaining row, then close the statement once 'fetch'
-- reports no more rows.
forEachRow' :: (Statement -> IO ()) -> Statement -> IO ()
forEachRow' f stmt = go
  where
    go = do
      more <- fetch stmt
      if more then f stmt >> go else closeStatement stmt
-- | Apply @f@ to each remaining row and collect the results in row order.
-- The statement is closed once 'fetch' reports no more rows.
collectRows :: (Statement -> IO a) -> Statement -> IO [a]
collectRows f stmt = do
  more <- fetch stmt
  if more
    then do
      x <- f stmt
      fmap (x :) (collectRows f stmt)
    else closeStatement stmt >> return []
|