import dai
import pandas as pd
from datetime import timedelta

def _delete_extra_column(data):
    """
    有时智能体会抽取更多的数据导致出现多余的字段被打上形如'_2'的记号, 如volume_2
    所以要把他们识别出来并删除
    """
    data = data.drop(columns=data.columns[data.columns.str.contains('_2')])
    return data

def get_data(feature_name, start_date, end_date):
    """因子抽取"""
    start_date = (pd.to_datetime(start_date) - timedelta(days=250)).strftime('%Y-%m-%d')
    columns = feature_name + ['date', 'instrument']
    feature_name = ', '.join(feature_name)
    sql = f"""
    SELECT date, instrument, {feature_name}
    FROM cn_stock_prefactors
    WHERE st_status = 0
    AND suspended = 0
    AND close / adjust_factor > 5
    AND instrument NOT LIKE '%B%'
    """
    data = dai.query(sql, filters={'date': [start_date, end_date]}).df()
    data = _delete_extra_column(data)
    return data

def data_filter(data, feature_name, direction, threshold):
    """数据筛选"""
    if direction == '>':
        data = data.query(f"{feature_name} > {threshold}")
    elif direction == '=':
        data = data.query(f"{feature_name} == {threshold}")
    elif direction == '<':
        data = data.query(f"{feature_name} < {threshold}")
    return data

def merge_data(data_1, data_2, how):
    """数据拼接"""
    df_1 = data_1
    df_2 = data_2

    # 拼接之前确保两张表没有重复的列
    c_a, c_b = set(df_1.columns), set(df_2.columns)
    if c_a < c_b:          # df_a 是 df_b 的真子集
        return df_2
    elif c_b < c_a:        # df_b 是 df_a 的真子集
        return df_1
    else:                  # 互不包含或完全相等
        if c_b == c_a:
            return df_1
        else:
            # 选列数多的；一样多时按表名字典序随便定一个
            drop_side, keep_side = (df_1, df_2) if len(c_a) <= len(c_b) else (df_2, df_1)
        common = list(drop_side.columns.intersection(keep_side.columns))
        if common:
            drop_side = drop_side.drop(columns=common)
    sql = f"""
    SELECT * FROM drop_side
    {how} JOIN drop_side USING (date, instrument)
    """
    data = dai.query(sql).df()
    data = _delete_extra_column(data)
    return data

def m_avg(data, feature_name, N):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, m_avg({feature_name}, {N})
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def c_avg(data, feature_name):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, c_avg({feature_name})
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def m_std(data, feature_name, N):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, m_stddev({feature_name}, {N})
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def c_std(data, feature_name):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, c_std({feature_name})
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def m_corr(data, feature_name_1, feature_name_2, N):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, m_corr({feature_name_1}, {feature_name_2}, {N})
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def m_max(data, feature_name, N):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, m_max({feature_name}, {N})
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def m_min(data, feature_name, N):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, m_min({feature_name}, {N})
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def c_z_normalize(data, feature_name):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, c_zscore({feature_name})
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def c_pct_rank(data, feature_name):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, c_pct_rank({feature_name})
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def m_pct_rank(data, feature_name, N):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, m_pct_rank({feature_name}, {N})
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def add(data, feature_name_1, feature_name_2):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, {feature_name_1} + {feature_name_2}
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def divide(data, feature_name_1, feature_name_2):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data.copy()
    sql = f"""
    SELECT {columns}, {feature_name_1} / {feature_name_2}
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def times(data, feature_name_1, feature_name_2):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, {feature_name_1} * {feature_name_2}
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def sub(data, feature_name_1, feature_name_2):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, {feature_name_1} - {feature_name_2}
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

def abs_func(data, feature_name):
    columns = list(data.columns)
    columns = ', '.join(columns)
    df = data
    sql = f"""
    SELECT {columns}, abs({feature_name})
    FROM df
    """
    data = dai.query(sql).df()
    data.columns = [i.replace('"', '') for i in list(data.columns)]
    data = _delete_extra_column(data)
    return data

tool_map = {
    "get_data": get_data, 
    "data_filter": data_filter, 
    "merge_data": merge_data, 
    "m_avg": m_avg, 
    "c_avg": c_avg, 
    "m_std": m_std, 
    "c_std": c_std, 
    "m_corr": m_corr, 
    "m_max": m_max, 
    "m_min": m_min, 
    "c_z_normalize": c_z_normalize, 
    "c_pct_rank": c_pct_rank, 
    "m_pct_rank": m_pct_rank, 
    "add": add, 
    "divide": divide, 
    "times": times, 
    "sub": sub, 
    "abs_func": abs_func
}

tool_description = """
get_data: 用于抽取特定时间段的因子数据的工具, 包含三个参数, feature_name 表示需要抽取的字段名列表, start_date表示开始日期, end_date表示结束日期, 日期格式形如'2020-01-01'这样的; 
data_filter: 用于筛选数据, 筛选出指定特征大于或者等于或者小于某个阈值的数据, 包含四个参数, data 表示需要筛选的原表数据, feature_name 表示需要筛选的字段, direction 表示判断的方向(目前只支持'>'、'='、'<'三种符号), threshold 表示阈值; 
merge_data: 用于连接两个数据表的工具, 包含三个参数, data_1 表示第一张需要拼接的表格, data_2 表示第二张需要拼接的表格, how 表示拼接方式(可选的拼接方式为: inner、right、left, 分别表示内连接、右连接、左连接); 
m_avg: 时序平均算子, 用于计算一段时期内的平均水平(典型应用为均线), 包含三个参数, data 表示需要参与时序运算的表格, feature_name 表示需要参与时序平均运算的字段名, N 表示算子需要计算当期往前取N-1期(包括当期)序列的平均数; 
c_avg: 截面平均算子, 用于计算时间截面上的平均水平, 包含两个参数, data 表示需要参与截面均值运算的表格, feature_name 表示需要参与截面平均运算的字段名; 
m_std: 时序标准差算子, 用于计算一段时期内的标准差, 包含三个参数, data 表示需要参与时序运算的表格, feature_name 表示需要参与时序标准差运算的字段名, N 表示算子需要计算当期往前取N-1期(包括当期)序列的标准差; 
c_std: 截面标准差算子, 用于计算时间截面上的标准差, 包含两个参数, data 表示需要参与截面标准差运算的表格, feature_name 表示需要参与截面标准差运算的字段名; 
m_corr: 时序相关性算子, 用于计算一段时间内两个因子的相关性, 包含四个参数, data表示需要参与时序运算的表格, feature_name_1 和 feature_name_2 都表示需要计算时序相关性的两个字段名, N 表示算子需要计算当期往前取N-1期(包括当期)序列的相关系数; 
m_max: 时序最大值算子, 用于计算一段时期内的最大值, 包含三个参数, data 表示需要参与时序运算的表格, feature_name 表示需要参与时序最大值运算的字段名, N 表示算子需要计算当期往前取N-1期(包括当期)序列的最大值; 
m_min: 时序最小值算子, 用于计算一段时期内的最小值, 包含三个参数, data 表示需要参与时序运算的表格, feature_name 表示需要参与时序最小值运算的字段名, N 表示算子需要计算当期往前取N-1期(包括当期)序列的最小值; 
c_z_normalize: 截面标准化算子, 用于对截面上的因子值进行标准化, 包含两个参数, data 表示需要参与标准化的表格数据, feature_name 表示需要参与截面标准化运算的字段名; 
c_pct_rank: 截面排名百分比算子, 用于计算指定特征的截面排名百分比的算子, 包含两个参数, data 表示需要参与截面排名百分比运算的表格, feature_name 表示需要参与截面排名百分比的字段名; 
m_pct_rank: 时序排名百分比算子, 用于计算指定特征的时序排名百分比算子, 包含三个参数, data 表示需要参与时序排名百分比运算的表格, feature_name 表示需要参与时序排名百分比的字段名, N 表示算子需要计算当期往前取N-1期(包括当期)序列的时序排名百分比; 
add: 加法算子, 用于计算两个因子的加总, 包含三个参数, data 表示需要参与加法运算的表格, feature_name_1 和 feature_name_2表示参与加法的两个因子名, 当然也可以传入数值!
divide: 除法算子, 用于计算两个因子的比值, 包含三个参数, data 表示需要参与除法运算的表格, feature_name_1 和 feature_name_2 表示参与除法的两个因子名, 当然也可以传入数值! feature_name_2 是除数, 除数不能为0; 
times: 乘法算子, 用于计算两个因子的乘积, data 表示需要参与乘法运算的表格, feature_name_1 和 feature_name_2 表示参与乘法的两个因子名, 当然也可以传入数值; 
sub: 减法算子, 用于计算两个因子的差值, data 表示需要参与减法运算的表格, feature_name_1 和 feature_name_2 表示参与减法的两个因子名, 当然也可以传入数值, 实用小技巧: feature_name_1 取 0 即可得到 feature_name_2 的相反数; 
abs_func: 绝对值算子, 用于得到因子的绝对值, data 表示需要参与运算的表格, feature_name 表示需要计算绝对值的字段名. 
"""
