from py2neo import Graph,Node
import psycopg2
from commonConfig import POSTGRE_CONFIG,NEO4J_CONFIG
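# Example of the configuration expected from commonConfig (reference only):
# POSTGRE_CONFIG = {
# 'host' : 'localhost',
# 'port' : '5432',
# 'user' : 'postgres',
# 'password' : '123456',
# 'database' : 'postgres', # default database; can be configured later
# # 'schema' : 'public' # default schema; can be configured later
# }
# NEO4J_CONFIG = {
# 'neo_url' : 'http://localhost:7474/',
# 'username' :'neo4j',
# 'password' : '123456' , # no authentication check by default
# }
# class Pg2Neo4j: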
def _quote_ident(name):
    """Return ``name`` as a double-quoted PostgreSQL identifier.

    Embedded double quotes are doubled, so any table/column name is emitted as
    a single literal identifier instead of being spliced into the SQL text.
    """
    return '"' + str(name).replace('"', '""') + '"'


def _quote_label(label):
    """Return ``label`` as a backtick-quoted Cypher label (embedded backticks doubled)."""
    return '`' + str(label).replace('`', '``') + '`'


def _resolve_columns(cur, pg_schema, pg_table, pg_columns_list, neo_properties_list):
    """Decide which PostgreSQL columns to read and which Neo4j property names to write.

    :param cur: open psycopg2 cursor; only queried when ``pg_columns_list`` is empty
    :param pg_schema: schema containing ``pg_table``
    :param pg_table: table name
    :param pg_columns_list: explicit column names, or empty for all columns
    :param neo_properties_list: explicit property names, or empty to reuse column names
    :return: ``(pg_columns, neo_properties)`` as two lists of equal length
    :raises ValueError: if the table has no visible columns, or if
        ``neo_properties_list`` does not match the number of selected columns
    """
    if pg_columns_list:
        pg_columns = list(pg_columns_list)
        length_error = 'The length of neo4j_properties_list and pg_columns_list not equal!'
    else:
        # No explicit column list: read every column in its declared order so a
        # positional neo_properties_list maps onto the columns deterministically.
        cur.execute(
            "SELECT column_name FROM information_schema.columns "
            "WHERE table_schema = %s AND table_name = %s "
            "ORDER BY ordinal_position",
            (pg_schema, pg_table),
        )
        pg_columns = [row[0] for row in cur.fetchall()]
        if not pg_columns:
            raise ValueError(f'Table {pg_schema}.{pg_table} not found or has no columns!')
        length_error = 'The length of neo4j_properties invalid!'

    if not neo_properties_list:
        neo_properties = list(pg_columns)
    elif len(neo_properties_list) == len(pg_columns):
        neo_properties = list(neo_properties_list)
    else:
        raise ValueError(length_error)
    return pg_columns, neo_properties


def get_from_pg(pg_table, neo_label=None, pg_schema='public',
                pg_columns_list=(), neo_properties_list=(),
                cover=True, limit=0,
                postgre_config=POSTGRE_CONFIG, neo4j_config=NEO4J_CONFIG):
    """Copy rows of a PostgreSQL table into Neo4j as labelled nodes.

    Each selected row becomes one node labelled ``neo_label``. Column values are
    stored as strings under the matching property names; SQL NULLs are omitted
    rather than stored as the string ``'None'``. All Neo4j writes happen in one
    transaction that is rolled back (and the error re-raised) if any write fails.

    :param pg_table: PostgreSQL table name
    :param neo_label: label for the created Neo4j nodes; defaults to ``pg_table``
    :param pg_schema: PostgreSQL schema containing ``pg_table``
    :param pg_columns_list: columns to read; empty means all columns in table order
    :param neo_properties_list: Neo4j property names, matched positionally to the
        selected columns; empty means reuse the column names
    :param cover: if true, first remove every existing node carrying ``neo_label``
        (``DETACH DELETE``, so their relationships are removed as well)
    :param limit: maximum number of rows to copy; 0 means all rows
    :param postgre_config: keyword arguments passed to ``psycopg2.connect``
    :param neo4j_config: keyword arguments passed to ``Graph``
    :return: None
    :raises ValueError: if the table has no columns or the property list length
        does not match the selected columns
    """
    pg_conn = psycopg2.connect(**postgre_config)
    try:
        with pg_conn.cursor() as cur:
            pg_columns, neo_properties = _resolve_columns(
                cur, pg_schema, pg_table, pg_columns_list, neo_properties_list)
            # Identifiers are quoted and schema-qualified so pg_schema is honoured.
            get_data = (
                f"SELECT {','.join(_quote_ident(c) for c in pg_columns)} "
                f"FROM {_quote_ident(pg_schema)}.{_quote_ident(pg_table)}"
            )
            if limit:
                # int() guarantees only a number reaches the SQL text.
                get_data += f" LIMIT {int(limit)}"
            print(get_data)
            cur.execute(get_data)
            data = cur.fetchall()
    finally:
        pg_conn.close()

    # Without an explicit label, the table name becomes the node label.
    if not neo_label:
        neo_label = pg_table
    tx = Graph(**neo4j_config).begin()
    try:
        if cover:
            delete_cypher = f'MATCH (a:{_quote_label(neo_label)}) DETACH DELETE a'
            print(delete_cypher)
            tx.run(delete_cypher)
        for row in data:
            # py2neo sends properties as query parameters, so no manual
            # quote/backslash escaping is needed (it would alter the stored data).
            properties = {
                name: str(value)
                for name, value in zip(neo_properties, row)
                if value is not None
            }
            # Labels are positional in Node(*labels, **properties); passing
            # `label=` would create an ordinary property instead of a label.
            node = Node(neo_label, **properties)
            print(node)
            tx.create(node)
    except Exception:
        tx.rollback()
        raise
    tx.commit()
if __name__ == '__main__':
    # Copy the PostgreSQL table "student" into Neo4j nodes labelled "student".
    get_from_pg(neo_label='student', pg_table='student')