diff --git a/locopy/utility.py b/locopy/utility.py index d19bf60..a903b41 100644 --- a/locopy/utility.py +++ b/locopy/utility.py @@ -251,14 +251,14 @@ def find_column_type(dataframe, warehouse_type: str): Following is the list of pandas data types that the function checks and their mapping in sql: - - bool -> boolean - - datetime64[ns] -> timestamp + - bool/pd.BooleanDtype -> boolean + - datetime64[ns, ] -> timestamp - M8[ns] -> timestamp - - int -> int - - float -> float + - int/pd.Int64Dtype -> int + - float/pd.Float64Dtype -> float - float object -> float - datetime object -> timestamp - - object -> varchar + - object/pd.StringDtype -> varchar For all other data types, the column will be mapped to varchar type. @@ -313,9 +313,9 @@ def validate_float_object(column): data = dataframe[column].dropna().reset_index(drop=True) if data.size == 0: column_type.append("varchar") - elif data.dtype in ["datetime64[ns]", "M8[ns]"]: + elif (data.dtype in ["datetime64[ns]", "M8[ns]"]) or (re.match("(datetime64\[ns\,\W)([a-zA-Z]+)(\])",str(data.dtype))): column_type.append("timestamp") - elif data.dtype == "bool": + elif str(data.dtype).lower().startswith("bool"): column_type.append("boolean") elif str(data.dtype).startswith("object"): data_type = validate_float_object(data) or validate_date_object(data) @@ -323,9 +323,9 @@ def validate_float_object(column): column_type.append("varchar") else: column_type.append(data_type) - elif str(data.dtype).startswith("int"): + elif str(data.dtype).lower().startswith("int"): column_type.append("int") - elif str(data.dtype).startswith("float"): + elif str(data.dtype).lower().startswith("float"): column_type.append("float") else: column_type.append("varchar") diff --git a/tests/test_utility.py b/tests/test_utility.py index e4139ac..bd32612 100644 --- a/tests/test_utility.py +++ b/tests/test_utility.py @@ -26,6 +26,7 @@ from itertools import cycle from pathlib import Path from unittest import mock +import datetime import pytest @@ -340,6 +341,52 @@ def test_find_column_type(): assert find_column_type(input_text, "snowflake") == output_text_snowflake assert find_column_type(input_text, "redshift") == output_text_redshift +def test_find_column_type_new(): + + from decimal import Decimal + + import pandas as pd + + input_text = pd.DataFrame.from_dict( + { + "a": [1], + "b": [pd.Timestamp('2017-01-01T12+0')], + "c": [1.2], + "d": ["a"], + "e": [True] + } +) + + input_text = input_text.astype( + dtype={ + "a": pd.Int64Dtype(), + "b": pd.DatetimeTZDtype(tz=datetime.timezone.utc), + "c": pd.Float64Dtype(), + "d": pd.StringDtype(), + "e": pd.BooleanDtype() + } + ) + + output_text_snowflake = { + "a": "int", + "b": "timestamp", + "c": "float", + "d": "varchar", + "e": "boolean", + } + + output_text_redshift = { + "a": "int", + "b": "timestamp", + "c": "float", + "d": "varchar", + "e": "boolean", + } + + assert find_column_type(input_text, "snowflake") == output_text_snowflake + assert find_column_type(input_text, "redshift") == output_text_redshift + + def test_get_ignoreheader_number(): assert (