From: <kr_...@us...> - 2004-01-05 00:41:07
|
Update of /cvsroot/htoolkit/HSQL/ODBC In directory sc8-pr-cvs1:/tmp/cvs-serv26990a/ODBC Modified Files: HSQL.hsc Log Message: Add tables and describe functions for ODBC backend Index: HSQL.hsc =================================================================== RCS file: /cvsroot/htoolkit/HSQL/ODBC/HSQL.hsc,v retrieving revision 1.12 retrieving revision 1.13 diff -C2 -d -r1.12 -r1.13 *** HSQL.hsc 4 Jan 2004 16:43:29 -0000 1.12 --- HSQL.hsc 5 Jan 2004 00:41:03 -0000 1.13 *************** *** 32,35 **** --- 32,37 ---- , forEachRow' -- :: (Statement -> IO ()) -> Statement -> IO () , collectRows -- :: (Statement -> IO s) -> Statement -> IO [s] + , tables -- :: Connection -> IO [String] + , describe -- :: Connection -> String -> IO [(String, SqlType, Bool)] ) where *************** *** 81,84 **** --- 83,88 ---- foreign import stdcall "HsODBC.h SQLTransact" sqlTransact :: HENV -> HDBC -> SQLUSMALLINT -> IO SQLRETURN foreign import stdcall "HsODBC.h SQLGetData" sqlGetData :: HSTMT -> SQLUSMALLINT -> SQLSMALLINT -> Ptr () -> SQLINTEGER -> Ptr SQLINTEGER -> IO SQLRETURN + foreign import stdcall "HsODBC.h SQLTables" sqlTables :: HSTMT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> IO SQLRETURN + foreign import stdcall "HsODBC.h SQLColumns" sqlColumns :: HSTMT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> IO SQLRETURN #else foreign import ccall "HsODBC.h SQLAllocEnv" sqlAllocEnv :: Ptr HENV -> IO SQLRETURN *************** *** 99,102 **** --- 103,108 ---- foreign import ccall "HsODBC.h SQLTransact" sqlTransact :: HENV -> HDBC -> SQLUSMALLINT -> IO SQLRETURN foreign import ccall "HsODBC.h SQLGetData" sqlGetData :: HSTMT -> SQLUSMALLINT -> SQLSMALLINT -> Ptr () -> SQLINTEGER -> Ptr SQLINTEGER -> IO SQLRETURN + foreign import ccall "HsODBC.h SQLTables" sqlTables :: HSTMT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> IO SQLRETURN + foreign import ccall "HsODBC.h SQLColumns" sqlColumns :: HSTMT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> IO SQLRETURN #endif *************** *** 181,185 **** sqlSuccess :: SQLRETURN -> Bool ! sqlSuccess res = (res == (#const SQL_SUCCESS)) || (res == (#const SQL_SUCCESS_WITH_INFO)) || (res == (#const SQL_NO_DATA)) --- 187,191 ---- sqlSuccess :: SQLRETURN -> Bool ! sqlSuccess res = (res == (#const SQL_SUCCESS)) || (res == (#const SQL_SUCCESS_WITH_INFO)) || (res == (#const SQL_NO_DATA)) *************** *** 255,259 **** sqlFreeConnect hDBC >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC return () ! ----------------------------------------------------------------------------------------- -- queries --- 261,265 ---- sqlFreeConnect hDBC >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC return () ! ----------------------------------------------------------------------------------------- -- queries *************** *** 278,285 **** free pFIELD ! -- | Executes the statement and returns a 'Statement' value which represents the result set ! query :: Connection -> String -> IO Statement ! query conn@(Connection {hDBC=hDBC}) query = do ! pFIELD <- mallocBytes (#const sizeof(FIELD)) res <- sqlAllocStmt hDBC ((#ptr FIELD, hSTMT) pFIELD) unless (sqlSuccess res) (free pFIELD) --- 284,290 ---- free pFIELD ! withStatement :: Connection -> (HSTMT -> IO SQLRETURN) -> IO Statement ! withStatement conn@(Connection {hDBC=hDBC}) f = do ! pFIELD <- mallocBytes (#const sizeof(FIELD)) res <- sqlAllocStmt hDBC ((#ptr FIELD, hSTMT) pFIELD) unless (sqlSuccess res) (free pFIELD) *************** *** 289,296 **** unless (sqlSuccess res) (free pFIELD) handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res ! pQuery <- newCString query ! res <- sqlExecDirect hSTMT pQuery (length query) ! free pQuery ! handleResult res sqlNumResultCols hSTMT ((#ptr FIELD, fieldsCount) pFIELD) >>= handleResult count <- (#peek FIELD, fieldsCount) pFIELD --- 294,298 ---- unless (sqlSuccess res) (free pFIELD) handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res ! res <- f hSTMT >>= handleResult sqlNumResultCols hSTMT ((#ptr FIELD, fieldsCount) pFIELD) >>= handleResult count <- (#peek FIELD, fieldsCount) pFIELD *************** *** 316,340 **** (fields, fullBufSize) <- getFieldDefs hSTMT pFIELD (n+1) count return ((name,sqlType,toBool nullable):fields, max bufSize fullBufSize) - - mkSqlType :: SQLSMALLINT -> SQLULEN -> SQLSMALLINT -> (SqlType, SQLINTEGER) - mkSqlType (#const SQL_CHAR) size _ = (SqlChar (fromIntegral size), (#const sizeof(SQLCHAR))*(fromIntegral size+1)) - mkSqlType (#const SQL_VARCHAR) size _ = (SqlVarChar (fromIntegral size), (#const sizeof(SQLCHAR))*(fromIntegral size+1)) - mkSqlType (#const SQL_LONGVARCHAR) size _ = (SqlLongVarChar (fromIntegral size), 1) -- dummy bufSize - mkSqlType (#const SQL_DECIMAL) size prec = (SqlDecimal (fromIntegral size) (fromIntegral prec), (#const sizeof(SQLDOUBLE))) - mkSqlType (#const SQL_NUMERIC) size prec = (SqlNumeric (fromIntegral size) (fromIntegral prec), (#const sizeof(SQLDOUBLE))) - mkSqlType (#const SQL_SMALLINT) _ _ = (SqlSmallInt, (#const sizeof(SQLSMALLINT))) - mkSqlType (#const SQL_INTEGER) _ _ = (SqlInteger, (#const sizeof(SQLINTEGER))) - mkSqlType (#const SQL_REAL) _ _ = (SqlReal, (#const sizeof(SQLDOUBLE))) - mkSqlType (#const SQL_DOUBLE) _ _ = (SqlDouble, (#const sizeof(SQLDOUBLE))) - mkSqlType (#const SQL_BIT) _ _ = (SqlBit, (#const sizeof(SQLINTEGER))) - mkSqlType (#const SQL_TINYINT) _ _ = (SqlTinyInt, (#const sizeof(SQLSMALLINT))) - mkSqlType (#const SQL_BIGINT) _ _ = (SqlBigInt, (#const sizeof(SQLINTEGER))) - mkSqlType (#const SQL_BINARY) size _ = (SqlBinary (fromIntegral size), (#const sizeof(SQLCHAR))*(fromIntegral size+1)) - mkSqlType (#const SQL_VARBINARY) size _ = (SqlVarBinary (fromIntegral size), (#const sizeof(SQLCHAR))*(fromIntegral size+1)) - mkSqlType (#const SQL_LONGVARBINARY)size _ = (SqlLongVarBinary (fromIntegral size), 1) -- dummy bufSize - mkSqlType (#const SQL_DATE) _ _ = (SqlDate, (#const sizeof(SQL_DATE_STRUCT))) - mkSqlType (#const SQL_TIME) _ _ = (SqlTime, (#const sizeof(SQL_TIME_STRUCT))) - mkSqlType (#const SQL_TIMESTAMP) _ _ = (SqlTimeStamp, (#const sizeof(SQL_TIMESTAMP_STRUCT))) {-# NOINLINE fetch #-} --- 318,346 ---- (fields, fullBufSize) <- getFieldDefs hSTMT pFIELD (n+1) count return ((name,sqlType,toBool nullable):fields, max bufSize fullBufSize) + mkSqlType :: SQLSMALLINT -> SQLULEN -> SQLSMALLINT -> (SqlType, SQLINTEGER) + mkSqlType (#const SQL_CHAR) size _ = (SqlChar (fromIntegral size), (#const sizeof(SQLCHAR))*(fromIntegral size+1)) + mkSqlType (#const SQL_VARCHAR) size _ = (SqlVarChar (fromIntegral size), (#const sizeof(SQLCHAR))*(fromIntegral size+1)) + mkSqlType (#const SQL_LONGVARCHAR) size _ = (SqlLongVarChar (fromIntegral size), 1) -- dummy bufSize + mkSqlType (#const SQL_DECIMAL) size prec = (SqlDecimal (fromIntegral size) (fromIntegral prec), (#const sizeof(SQLDOUBLE))) + mkSqlType (#const SQL_NUMERIC) size prec = (SqlNumeric (fromIntegral size) (fromIntegral prec), (#const sizeof(SQLDOUBLE))) + mkSqlType (#const SQL_SMALLINT) _ _ = (SqlSmallInt, (#const sizeof(SQLSMALLINT))) + mkSqlType (#const SQL_INTEGER) _ _ = (SqlInteger, (#const sizeof(SQLINTEGER))) + mkSqlType (#const SQL_REAL) _ _ = (SqlReal, (#const sizeof(SQLDOUBLE))) + mkSqlType (#const SQL_DOUBLE) _ _ = (SqlDouble, (#const sizeof(SQLDOUBLE))) + mkSqlType (#const SQL_BIT) _ _ = (SqlBit, (#const sizeof(SQLINTEGER))) + mkSqlType (#const SQL_TINYINT) _ _ = (SqlTinyInt, (#const sizeof(SQLSMALLINT))) + mkSqlType (#const SQL_BIGINT) _ _ = (SqlBigInt, (#const sizeof(SQLINTEGER))) + mkSqlType (#const SQL_BINARY) size _ = (SqlBinary (fromIntegral size), (#const sizeof(SQLCHAR))*(fromIntegral size+1)) + mkSqlType (#const SQL_VARBINARY) size _ = (SqlVarBinary (fromIntegral size), (#const sizeof(SQLCHAR))*(fromIntegral size+1)) + mkSqlType (#const SQL_LONGVARBINARY)size _ = (SqlLongVarBinary (fromIntegral size), 1) -- dummy bufSize + mkSqlType (#const SQL_DATE) _ _ = (SqlDate, (#const sizeof(SQL_DATE_STRUCT))) + mkSqlType (#const SQL_TIME) _ _ = (SqlTime, (#const sizeof(SQL_TIME_STRUCT))) + mkSqlType (#const SQL_TIMESTAMP) _ _ = (SqlTimeStamp, (#const sizeof(SQL_TIMESTAMP_STRUCT))) + + -- | Executes the statement and returns a 'Statement' value which represents the result set + query :: Connection -> String -> IO Statement + query conn q = withStatement conn doQuery + where doQuery hSTMT = withCStringLen q (uncurry (sqlExecDirect hSTMT)) {-# NOINLINE fetch #-} *************** *** 346,352 **** handleSqlResult (#const SQL_HANDLE_STMT) (hSTMT stmt) res return (res /= (#const SQL_NO_DATA)) ! ! -- | 'closeStatement' stops processing associated with a specific statement, closes any open cursors ! -- associated with the statement, discards pending results, and frees all resources associated with -- the statement. closeStatement :: Statement -> IO () --- 352,358 ---- handleSqlResult (#const SQL_HANDLE_STMT) (hSTMT stmt) res return (res /= (#const SQL_NO_DATA)) ! ! -- | 'closeStatement' stops processing associated with a specific statement, closes any open cursors ! -- associated with the statement, discards pending results, and frees all resources associated with -- the statement. closeStatement :: Statement -> IO () *************** *** 356,359 **** --- 362,409 ---- ----------------------------------------------------------------------------------------- + -- getting table and column info + ----------------------------------------------------------------------------------------- + + -- | List all tables in the database. + tables :: Connection -- ^ Database connection + -> IO [String] -- ^ The names of all tables in the database. + tables conn = do + stmt <- withStatement conn sqlTables' + -- SQLTables returns: + -- Column name # Type + -- TABLE_NAME 3 VARCHAR + collectRows (\s -> getFieldValue' s "TABLE_NAME" "") stmt + where sqlTables' hSTMT = sqlTables hSTMT nullPtr 0 nullPtr 0 nullPtr 0 nullPtr 0 + + -- | List all columns in a table along with their types and @nullable@ flags + describe :: Connection -- ^ Database connection + -> String -- ^ Name of a database table + -> IO [(String, SqlType, Bool)] -- ^ @[(name, type, nullable)]@ + describe conn table = do + stmt <- withStatement conn (\hSTMT -> sqlColumns' hSTMT table) + collectRows getColumnInfo stmt + where + sqlColumns' hSTMT table = + withCStringLen table (\(pTable,len) -> + sqlColumns hSTMT nullPtr 0 nullPtr 0 pTable (fromIntegral len) nullPtr 0) + -- SQLColumns returns: + -- Column name # Type + -- COLUMN_NAME 4 Varchar not NULL + -- DATA_TYPE 5 Smallint not NULL + -- COLUMN_SIZE 7 Integer + -- DECIMAL_DIGITS 9 Smallint + -- NULLABLE 11 Smallint not NULL + getColumnInfo stmt = + do + name <- getFieldValue stmt "COLUMN_NAME" + (t::Int) <- getFieldValue stmt "DATA_TYPE" + (size::Int) <- getFieldValue' stmt "COLUMN_SIZE" 0 + (prec::Int) <- getFieldValue' stmt "DECIMAL_DIGITS" 0 + (n::Int) <- getFieldValue stmt "NULLABLE" + let (sqlType,_) = mkSqlType (fromIntegral t) (fromIntegral size) (fromIntegral prec) + nullable = n /= (#const SQL_NO_NULLS) + return (name, sqlType, nullable) + + ----------------------------------------------------------------------------------------- -- transactions ----------------------------------------------------------------------------------------- *************** *** 391,395 **** toSqlValue val = show val ! instance SqlBind Integer where fromSqlValue SqlInteger ptr size = do --- 441,445 ---- toSqlValue val = show val ! instance SqlBind Integer where fromSqlValue SqlInteger ptr size = do |