# mdbs/anchor.py
from collections import defaultdict
import argparse
import copy
import json
import os
import time

import matplotlib.pyplot as plt
from matplotlib import font_manager
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.optimize import curve_fit
def find_max_indices_numpy(L_dict):
keys_arr = np.array(list(L_dict.keys()))
values_arr = np.array(list(L_dict.values()))
    max_val_from_dict = max(L_dict.values())  # equivalently: np.max(values_arr)
indices = np.where(values_arr == max_val_from_dict)[0]
max_keys_np_way = keys_arr[indices]
return max_keys_np_way, max_val_from_dict
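# Minimal usage sketch (toy data, not part of the pipeline): every key tied for the maximum
# value is returned together with that maximum.
def _example_find_max_indices():
    keys, max_val = find_max_indices_numpy({'u1': 3, 'u2': 5, 'u3': 5})
    print(keys, max_val)  # -> ['u2' 'u3'] 5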
def R(event_dict):
mean = 0
for _,v in event_dict.items():
mean += -v
mean /= len(event_dict)
return mean
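# Minimal sketch of the reward: R maps a user's {event: normalized activation time} dict to the
# negative mean activation time, so users who join cascades earlier receive higher rewards.
def _example_reward():
    print(R({'c1': 0.1, 'c2': 0.3}))  # -> -0.2 (earlier participation => larger reward)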
def select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b):
"""
Parameters:
- R: reward function, R(A) returns a numeric value
- L_dict: dict of node -> integer (event num)
- user_event_dict: user -> event -> active_time
- event_user_dict: event -> user_list
- b: budget (numeric)
Returns:
- A: selected set of social sensors
"""
print("Select social sensors begin!")
user_event_dict_ = copy.deepcopy(user_event_dict)
L_dict_ = copy.deepcopy(L_dict)
V = list(L_dict_.keys())
A = []
f = lambda A: len(A)
all_cas = 0
while any(s not in A for s in V) and f(A) < b:
indices_np, max_val_from_dict = find_max_indices_numpy(L_dict_)
if max_val_from_dict == 0:
print("no extra node in cas!!!!!")
break
TAR_set = set(indices_np)
delta = {}
cur = {}
for s in TAR_set:
delta[s] = float('inf')
cur[s] = False
c_star = []
while True:
s_star = max(delta, key=delta.get)
c_star = user_event_dict_[s_star]
if cur[s_star] == True:
if delta[s_star] >= args.anchor_screen:
A.append(s_star)
break
else:
delta[s_star] = R(c_star)
cur[s_star] = True
if delta[s_star] < args.anchor_screen:
L_dict_[s_star] = 0
continue
all_cas += len(c_star)
for cas_id in list(c_star.keys()):
uc = event_user_dict[cas_id]
for v in uc:
if v in L_dict_.keys():
L_dict_[v] -= 1
_ = user_event_dict_[v].pop(cas_id)
print(f"Add a social sensor, sensors num is {len(A)}")
print(f"Anchor id: {s_star}")
print(f"all_cas is {all_cas}")
print(f"TAR_set size: {len(TAR_set)}")
if all_cas >= len(event_user_dict):
user_event_dict_ = copy.deepcopy(user_event_dict)
L_dict_ = copy.deepcopy(L_dict)
for s in A:
L_dict_[s] = 0
all_cas = 0
print(f"All cas has been perception, add extra node")
# break
print(f"Select social sensors finish! Get social sensors, num: {len(A)}")
return A, all_cas
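# Hedged usage sketch (toy data, hypothetical values; not part of the original pipeline).
# select_social_sensors reads `anchor_screen` from the module-level `args`, which the real
# entry point below provides, so this sketch supplies a stand-in Namespace before calling it.
def _example_select_social_sensors():
    global args
    args = argparse.Namespace(anchor_screen=-0.8)
    user_event_dict = {
        'u1': {'c1': 0.1, 'c2': 0.2},
        'u2': {'c1': 0.9},
        'u3': {'c2': 0.5, 'c3': 0.3},
    }
    event_user_dict = {'c1': ['u1', 'u2'], 'c2': ['u1', 'u3'], 'c3': ['u3']}
    L_dict = {u: len(events) for u, events in user_event_dict.items()}
    A, covered = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b=2)
    print(A, covered)  # greedily picks the users that cover the toy cascades earliest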
def handle_cas(args):
filename, result_path, obs_time = args.input_file, args.result_path, args.obs_time
print("Handle cascade begin!")
user_event_dict = defaultdict(dict)
event_user_dict = defaultdict(list)
cascades_total = 0
# t_min_cas = -1
t_min_all = args.t_min
with open(filename) as file:
for line in file:
cascades_total += 1
if cascades_total > args.obs_cas_num and args.obs_cas_num >0:
break
parts = line.split(',')
cascade_id = parts[0]
activation_times = {}
paths = parts[1:]
t_max = 0
t_min = float('inf')
for p in paths:
# observed adoption/participant
n_t = p.split(':')
if len(n_t) < 2:
continue
nodes = n_t[0].split('/')
time_now = int(n_t[1])
if time_now > t_max:
t_max = time_now
if time_now < t_min:
t_min = time_now
node = nodes[-1]
node_id = node
if time_now > obs_time and obs_time != -1:
continue
if node_id in activation_times.keys():
activation_times[node_id] = min(time_now, activation_times[node_id])
else:
activation_times[node_id] = time_now
print(f"cascade_id:{cascade_id}, t_min:{t_min}")
if len(activation_times) < args.cas_min_length:
print(f"Too short Cas! Cas_id:{cascade_id}")
continue
if t_min > args.t_max or t_min < args.t_min:
print(f"Cas is not in proper time! Cas_id:{cascade_id}")
continue
for k,v in activation_times.items():
event_user_dict[cascade_id].append(k)
if t_max > t_min:
user_event_dict[k][cascade_id] = (v-t_min)/(t_max-t_min)
# print(f"t_min_cas: {t_min_cas}")
L_dict = {}
for k,v in user_event_dict.items():
L_dict[k] = len(v.keys())
print(f"Handle cascade file finish! Users num is {len(L_dict)}")
os.makedirs(result_path, exist_ok=True)
json_file = os.path.join(result_path, f'user_event_{args.obs_cas_num}.json')
try:
with open(json_file, 'w', encoding='utf-8') as f:
json.dump(user_event_dict, f, indent=4, ensure_ascii=False)
print(f"user_event_dict已成功保存到 JSON 文件: {json_file}")
except TypeError as e:
print(f"错误: 字典中可能包含无法序列化为 JSON 的数据类型。错误: {e}")
except IOError:
print(f"错误: 无法打开或写入文件 '{json_file}'")
except Exception as e:
print(f"保存 JSON 文件时发生未知错误: {e}")
return L_dict, user_event_dict, event_user_dict
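# Hedged sketch of the cascade file format as inferred from the parser above (the file name and
# contents here are illustrative only): each line is a cascade id followed by comma-separated
# "path:timestamp" records, where the last '/'-separated node on a path is the newly
# activated user.
def _example_handle_cas():
    demo_file = './_demo_cascades.txt'  # hypothetical path
    with open(demo_file, 'w') as f:
        f.write("c1,a:0,a/b:10,a/b/c:20,a/c/d:30\n")
        f.write("c2,b:5,b/a:15,b/a/e:25\n")
    demo_args = argparse.Namespace(input_file=demo_file, result_path='./_demo_result/',
                                   obs_time=-1, obs_cas_num=-1, cas_min_length=2,
                                   t_min=0, t_max=10**12)
    return handle_cas(demo_args)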
def generate_anchors(args):
input_file, result_path, anchor_budget, obs_time = args.input_file, args.result_path, args.anchor_budget, args.obs_time
L_dict, user_event_dict, event_user_dict = handle_cas(args)
num_nodes = len(L_dict.keys())
anchor_num = 0
if anchor_budget > 0.02:
max_anchor_num = int(num_nodes*0.02)
print(f"Max anchor num is {max_anchor_num}, anchor_budget is set to {max_anchor_num}")
anchor_num = max_anchor_num
else:
anchor_num = int(num_nodes*anchor_budget)
if args.anchor_num != -1:
anchor_num = args.anchor_num
print(f"anchor_num is set to {anchor_num}")
# anchor_num = 10
A, all_cas = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, anchor_num)
os.makedirs(result_path, exist_ok=True)
output_file = os.path.join(result_path, f'anchors_{args.anchor_budget}.txt')
with open(output_file, 'w') as file:
for item in A:
file.write(f"{item}\n")
return A, all_cas, user_event_dict
# Logistic function: K / (1 + exp(-r * (t - t0))), with inflection (outbreak) point at t0
def logistic(t, K, r, t0):
return K / (1 + np.exp(-r * (t - t0)))
def find_inflection_logistic(t_list, c_list):
t_array = np.array(t_list)
c_array = np.array(c_list)
    # normalize time to [0, 1]
t_min = t_array.min()
t_max = t_array.max()
if len(set(c_array)) < 3 or max(c_array) <= 1 or t_max - t_min == 0:
return None
t_array_normalized = (t_array - t_min) / (t_max - t_min)
try:
        p0 = [max(c_array) if max(c_array) > 0 else 1.0, 1.0, 0.5]  # keep the initial K positive even if c_array is all zeros
        # check that c_array actually grows; an S-curve cannot be fitted otherwise
        if max(c_array) <= 0:
            print("Warning: cumulative counts (c_list) show no positive growth; cannot fit an S-curve.")
            return None
        # bounds: K > 0, r > 0 (growth rate), t0 within the normalized range [0, 1]
        bounds = ([1e-6, 1e-6, 0], [np.inf, 10, 1])
popt, _ = curve_fit(logistic, t_array_normalized, c_array, p0=p0, bounds=bounds, maxfev=10000)
K, r, t0 = popt
return t0
except (RuntimeError, ValueError):
print(f"警告: 曲线拟合失败或参数无效")
return None
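# Hedged self-check sketch (synthetic data): sample a logistic curve with a known inflection
# point and verify that find_inflection_logistic recovers roughly t0 = 0.4 on the
# normalized timeline.
def _example_find_inflection():
    t_raw = np.linspace(0, 100, 60)                        # raw timestamps
    c_cum = logistic(np.linspace(0, 1, 60), 200, 8, 0.4)   # cumulative adopters along an S-curve
    print(find_inflection_logistic(list(t_raw), list(c_cum)))  # ~0.4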
# Compute the outbreak time t0 of every cascade
def get_t0(args):
input_file, result_path = args.input_file, args.result_path
print("To get All Cas' T0")
os.makedirs(result_path, exist_ok=True)
json_file = os.path.join(result_path, 'prepare.json')
prepare = {}
if os.path.exists(json_file):
print(f"文件 {json_file} 存在。")
with open(json_file, 'r', encoding='utf-8') as f:
prepare = json.load(f)
return prepare
t0_dict = {}
    num_time_points = 50  # number of sample points for the averaged trend; adjustable
common_normalized_timeline = np.linspace(0, 1, num_time_points)
nodes_in_bins = {i: 0 for i in range(len(common_normalized_timeline) - 1)}
interpolated_counts_all_cascades = []
cascades_total = 0
t0_mean = 0
with open(input_file) as file:
for line in file:
cascades_total += 1
if cascades_total > args.obs_cas_num and args.obs_cas_num >0:
break
parts = line.split(',')
cascade_id = parts[0]
paths = parts[1:]
t_max = 0
t_min = float('inf')
cas_set = set()
times = []
counts = []
for p in paths:
nodes = p.split(':')[0].split('/')
time_now = int(p.split(':')[1])
if time_now > t_max:
t_max = time_now
if time_now < t_min:
t_min = time_now
node = nodes[-1]
node_id = node
cas_set.add(node_id)
times.append(time_now)
counts.append(len(cas_set))
sorted_times = sorted(times)
if t_max - t_min <= 0:
continue
nodes_in_bin = {i: 0 for i in range(len(common_normalized_timeline) - 1)}
for i in range(len(times)):
participation_time = (sorted_times[i] - t_min)/(t_max - t_min)
sorted_times[i] = participation_time
bin_index = np.digitize(participation_time, common_normalized_timeline, right=False) - 1
if bin_index == -1 and participation_time == common_normalized_timeline[0]:
bin_index = 0
elif bin_index == len(common_normalized_timeline) - 1:
if participation_time == common_normalized_timeline[-1]:
bin_index = len(common_normalized_timeline) - 2
else:
continue
if 0 <= bin_index < len(nodes_in_bin):
nodes_in_bins[bin_index] +=1
nodes_in_bin[bin_index] += 1
interpolated_counts = np.interp(common_normalized_timeline,
sorted_times,
counts)
estimated_increments_between_common_times = np.diff(interpolated_counts)
interpolated_counts_all_cascades.append(estimated_increments_between_common_times)
t0 = find_inflection_logistic(sorted_times, counts)
if t0 is None:
continue
t0_dict[cascade_id] = [t0, (t_max - t_min)/3600]
t0_mean += t0
t0_mean /= len(t0_dict)
stacked_interpolated_counts = np.vstack(interpolated_counts_all_cascades).tolist()
counts_0 = np.array([nodes_in_bins[i] for i in sorted(nodes_in_bins.keys())])
all_counts = (counts_0 / counts_0.max()).tolist()
prepare['t0'] = t0_dict
prepare['t0_mean'] = t0_mean
prepare['interpolated_counts_sum'] = stacked_interpolated_counts
prepare['interpolated_counts_step_avg'] = all_counts
try:
with open(json_file, 'w', encoding='utf-8') as f:
json.dump(prepare, f, indent=4, ensure_ascii=False)
print(f"t0_dict已成功保存到 JSON 文件: {json_file}")
except TypeError as e:
print(f"错误: 字典中可能包含无法序列化为 JSON 的数据类型。错误: {e}")
except IOError:
print(f"错误: 无法打开或写入文件 '{json_file}'")
except Exception as e:
print(f"保存 JSON 文件时发生未知错误: {e}")
return prepare
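# Shape of the `prepare` dict assembled above (keys are taken from the code; the values shown
# are illustrative placeholders only):
#   {
#     "t0": {"<cascade_id>": [t0_on_normalized_timeline, duration_in_hours], ...},
#     "t0_mean": 0.37,
#     "interpolated_counts_sum": [[...per-cascade increments between the 50 sampled times...], ...],
#     "interpolated_counts_step_avg": [...max-normalized activation counts per time bin...]
#   }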
def analyse_result(node_list, user_event_dict, prepare, args, set_name):
result_path = args.result_path
t0_dict = prepare['t0']
t0_mean = prepare['t0_mean']
anchors_activation_times_before_t0_ = defaultdict(list)
anchors_activation_times_before_t0 = defaultdict(list)
anchors_activation_times_ = defaultdict(list)
anchors_activation_times = defaultdict(list)
cas_set = set()
for k,v in user_event_dict.items():
if k in node_list:
for cas_id, ac_time in v.items():
cas_set.add(cas_id)
if cas_id not in t0_dict:
continue
anchors_activation_times_before_t0_[k].append(ac_time - t0_dict[cas_id][0])
anchors_activation_times_before_t0[k].append((ac_time - t0_dict[cas_id][0])*t0_dict[cas_id][1])
anchors_activation_times_[k].append(ac_time)
anchors_activation_times[k].append(ac_time*t0_dict[cas_id][1])
print(f"感知到的级联数量:{len(cas_set)}")
anchors_avg_activation_time_before_t0 = dict()
anchors_avg_activation_time = dict()
for node_id, times in anchors_activation_times_before_t0.items():
if times:
anchors_avg_activation_time_before_t0[node_id] = (sum(times) / len(times))
else:
anchors_avg_activation_time_before_t0[node_id] = None
for node_id, times in anchors_activation_times.items():
if times:
anchors_avg_activation_time[node_id] = (sum(times) / len(times))
else:
anchors_avg_activation_time[node_id] = None
anchors_avg_activation_time_before_t0_ = dict()
anchors_avg_activation_time_ = dict()
for node_id, times in anchors_activation_times_before_t0_.items():
if times:
anchors_avg_activation_time_before_t0_[node_id] = (sum(times) / len(times))
else:
anchors_avg_activation_time_before_t0_[node_id] = None
for node_id, times in anchors_activation_times_.items():
if times:
anchors_avg_activation_time_[node_id] = (sum(times) / len(times))
else:
anchors_avg_activation_time_[node_id] = None
anchors_total_avg_time_before_t0 = sum(t for t in anchors_avg_activation_time_before_t0.values() if t is not None) / len(anchors_avg_activation_time_before_t0)
anchors_total_avg_time = sum(t for t in anchors_avg_activation_time.values() if t is not None) / len(anchors_avg_activation_time)
anchors_total_avg_time_before_t0_ = sum(t for t in anchors_avg_activation_time_before_t0_.values() if t is not None) / len(anchors_avg_activation_time_before_t0_)
anchors_total_avg_time_ = sum(t for t in anchors_avg_activation_time_.values() if t is not None) / len(anchors_avg_activation_time_)
result_path = os.path.join(result_path, f'{set_name}')
os.makedirs(result_path, exist_ok=True)
anchor_event_dict = {}
for anchor in node_list:
anchor_event_dict[anchor] = user_event_dict[anchor]
anchor_event_file = os.path.join(result_path, f'anchor_event_{args.anchor_budget}.json')
try:
with open(anchor_event_file, 'w', encoding='utf-8') as f:
json.dump(anchor_event_dict, f, indent=4, ensure_ascii=False)
print(f"anchor_event_dictt已成功保存到 JSON 文件: {anchor_event_file}")
except TypeError as e:
print(f"错误: 字典中可能包含无法序列化为 JSON 的数据类型。错误: {e}")
except IOError:
print(f"错误: 无法打开或写入文件 '{anchor_event_file}'")
except Exception as e:
print(f"保存 JSON 文件时发生未知错误: {e}")
    analyze_result = {}
    analyze_result['anchors_act_cas_num'] = len(cas_set)
    analyze_result['act_cas'] = list(cas_set)
    analyze_result['anchors_total_avg_time_before_t0'] = anchors_total_avg_time_before_t0
    analyze_result['anchors_avg_activation_time_before_t0'] = anchors_avg_activation_time_before_t0
    analyze_result['anchors_total_avg_time_before_t0_'] = anchors_total_avg_time_before_t0_
    analyze_result['anchors_avg_activation_time_before_t0_'] = anchors_avg_activation_time_before_t0_
    analyze_result['anchors_total_avg_time'] = anchors_total_avg_time
    analyze_result['anchors_avg_activation'] = anchors_avg_activation_time
    analyze_result['anchors_total_avg_time_'] = anchors_total_avg_time_
    analyze_result['anchors_avg_activation_'] = anchors_avg_activation_time_
    json_file = os.path.join(result_path, f'analyze_result_{args.anchor_budget}.json')
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(analyze_result, f, indent=4, ensure_ascii=False)
        print(f"analyze_result successfully saved to JSON file: {json_file}")
    except TypeError as e:
        print(f"Error: the dict may contain data types that cannot be serialized to JSON. Error: {e}")
    except IOError:
        print(f"Error: cannot open or write file '{json_file}'")
    except Exception as e:
        print(f"An unexpected error occurred while saving the JSON file: {e}")
print("开始绘制激活时间图 ")
# --- 绘制激活时间图 ---
important_times = [t for t in anchors_avg_activation_time_before_t0_.values() if t is not None]
# 然后继续绘制图表
# 假设已经有 important_times 和 random_times
plt.figure(figsize=(10, 6))
sns.kdeplot(important_times, label='anchors', fill=True, color='red', linewidth=2)
plt.title('Anchors activate time')
plt.xlabel('Time')
plt.ylabel('P')
plt.legend()
plt.grid(True)
plt.tight_layout()
fig_file_1 = os.path.join(result_path, f'anchor_act_time_{args.anchor_budget}.png')
    plt.savefig(fig_file_1,          # output file name
                dpi=300,             # optional: resolution in dots per inch
                bbox_inches='tight', # optional: trim surrounding whitespace
                )
plt.close()
print("开始绘制趋势图 ")
# --- 绘制趋势图 ---
min_time = 0.0
max_time = 1 # 假设最大观察时间是 0.5,你可以根据你的数据调整
bin_width = 0.01
num_bins = int(np.ceil((max_time - min_time) / bin_width)) # 向上取整计算bin的数量
# 创建时间段的边界
# 例如:[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
time_bin_edges = np.linspace(min_time, max_time, num_bins + 1)
nodes_in_bins = {i: 0 for i in range(len(time_bin_edges) - 1)}
for node, participation_time in anchors_avg_activation_time_.items():
bin_index = np.digitize(participation_time, time_bin_edges, right=False) - 1
if bin_index == -1 and participation_time == time_bin_edges[0]:
bin_index = 0
elif bin_index == len(time_bin_edges) - 1:
if participation_time == time_bin_edges[-1]:
bin_index = len(time_bin_edges) - 2
else:
continue
if 0 <= bin_index < len(nodes_in_bins):
nodes_in_bins[bin_index] += 1
plt.figure(figsize=(10, 6))
bin_labels = []
for i in range(len(time_bin_edges) - 1):
start_time = time_bin_edges[i]
end_time = time_bin_edges[i+1]
        # label each bin by its interval string (alternatively, use the bin midpoint)
        bin_labels.append(f"[{start_time:.2f}-{end_time:.2f})")
        # bin_labels.append(f"{start_time:.1f}-")  # shorter label showing only the bin start
    # counts in each bin
    counts_0 = np.array([nodes_in_bins[i] for i in sorted(nodes_in_bins.keys())])
    counts = (counts_0 / counts_0.max()).tolist()
    # x positions of the bars
    x_positions = time_bin_edges[:-1] + bin_width/2
    # The bin counts are presumably meant to be drawn as a bar chart (they are otherwise unused),
    # so that is sketched here under that assumption.
    plt.bar(x_positions, counts, width=bin_width, label='Anchor activation trend')
    # --- vertical red line marking the mean outbreak time ---
    plt.axvline(x=t0_mean,          # x coordinate of the line
                color='red',        # line colour
                linestyle='--',     # line style (e.g. 'solid', '--', '-.', ':')
                linewidth=2,        # line width
                label=f'Average outbreak time = {t0_mean:.3f}')
    plt.legend()
fig_file_2 = os.path.join(result_path, f'trend_{args.anchor_budget}.png')
    plt.savefig(fig_file_2,          # output file name
                dpi=300,             # optional: resolution in dots per inch
                bbox_inches='tight', # optional: trim surrounding whitespace
                )
plt.close()
return
def parse_args():
parser = argparse.ArgumentParser(description='Parameters')
parser.add_argument('--input_file', default='./dataset_for_anchor.txt', type=str, help='Cascade file')
parser.add_argument('--result_path', default='./result/', type=str, help='result save path')
    parser.add_argument('--anchor_budget', default=0.005, type=float, help='Anchor budget as a fraction of all users')
    parser.add_argument('--anchor_num', default=1000, type=int, help='Exact number of anchors; -1 to derive it from anchor_budget')
    parser.add_argument('--obs_cas_num', default=-1, type=int, help='Number of cascades to read; -1 means read all')
    parser.add_argument('--obs_time', default=-1, type=int, help='Observation time cutoff; the default -1 means the full cascade is observed')
    parser.add_argument('--key_node_file', default='/home/zjy/project/CasDO/dataset/nh/key_node_10', type=str, help='File with key nodes mined by another algorithm')
    parser.add_argument('--anchor_screen', default=-0.8, type=float, help='Reward threshold used to screen candidate anchors')
    parser.add_argument('--cas_min_length', default=20, type=int, help='Minimum cascade length')
    parser.add_argument('--t_max', default=1751299200, type=int, help='Latest allowed cascade start time (Unix timestamp)')
    parser.add_argument('--t_min', default=1751299200, type=int, help='Earliest allowed cascade start time (Unix timestamp)')
return parser.parse_args()
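# Example invocation (hypothetical paths and values, shown for illustration only):
#   python mdbs/anchor.py --input_file ./dataset_for_anchor.txt --result_path ./result/ \
#       --anchor_budget 0.005 --anchor_num 1000 --anchor_screen -0.8 --cas_min_length 20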
if __name__ == '__main__':
args = parse_args()
    #=========== Run the anchor (social sensor) mining algorithm; it returns the anchors and the number of perceived cascades ===========
A, all_cas, user_event_dict = generate_anchors(args)
    #=========== Analyse the results ===========
    print("Start analysing the results")
prepare = get_t0(args)
set_name = 'anchor'
analyse_result(A, user_event_dict, prepare, args, set_name)
    # #=========== Load important nodes mined by another algorithm ===========
    # with open(args.key_node_file, 'r') as file:
    #     # strip() removes the trailing newline from each line; empty or non-numeric lines are filtered out
    #     key_node_list = [line.strip() for line in file if line.strip().isdigit()]
    # key_node_set = set(key_node_list[:len(A)])
    # json_file = os.path.join(args.result_path, 'user_event.json')
    # user_event_dict = {}
    # if os.path.exists(json_file):
    #     print(f"File {json_file} exists.")
    #     with open(json_file, 'r', encoding='utf-8') as f:
    #         user_event_dict = json.load(f)
    # # prepare = get_t0(args)
    # set_name_d = 'degree'
    # analyse_result(key_node_set, user_event_dict, prepare, args, set_name_d)