# mdbs/anchor.py
from collections import defaultdict
import argparse
import copy
import json
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib import font_manager
from scipy.optimize import curve_fit

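# Pipeline overview (a summary inferred from the code below): this script mines
# "anchor" users (social sensors) from a cascade file with a greedy coverage
# procedure, estimates each cascade's outbreak time t0 by fitting a logistic
# curve to its cumulative participation, and writes JSON results plus two
# matplotlib figures under --result_path.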
def find_max_indices_numpy(L_dict):
    keys_arr = np.array(list(L_dict.keys()))
    values_arr = np.array(list(L_dict.values()))
    max_val_from_dict = max(L_dict.values())  # or np.max(values_arr)
    indices = np.where(values_arr == max_val_from_dict)[0]
    max_keys_np_way = keys_arr[indices]
    return max_keys_np_way

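# R(event_dict) returns the negated sum of a user's normalized activation times,
# so among candidates participating in the same number of cascades, the one who
# participates earlier receives the larger (less negative) reward. For example,
# {c1: 0.1, c2: 0.3} gives R = -0.4, while {c1: 0.7, c2: 0.8} gives R = -1.5.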
def R(event_dict):
    total = 0
    for _, v in event_dict.items():
        total += -v
    return total

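# The selection below looks like a lazy-greedy (CELF-style) loop: candidates are
# restricted to the users covering the most remaining cascades
# (find_max_indices_numpy), their cached gains start at +inf, and the top entry
# is only accepted once its reward has been re-evaluated on the current state
# (cur[s] == True). After a sensor is added, every cascade it participates in is
# removed from the pool, so later picks are judged only on uncovered cascades.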
def select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b):
    """
    Parameters:
    - R: reward function, R(A) returns a numeric value
    - L_dict: dict of node -> integer (event num)
    - user_event_dict: user -> event -> active_time
    - event_user_dict: event -> user_list
    - b: budget (numeric)
    Returns:
    - A: selected set of social sensors
    """
    print("Select social sensors begin!")
    user_event_dict = copy.deepcopy(user_event_dict)
    V = list(L_dict.keys())
    A = set()
    f = lambda A: len(A)
    all_cas = 0
    while any(s not in A for s in V) and f(A) < b:
        indices_np = find_max_indices_numpy(L_dict)
        TAR_set = set(indices_np)
        delta = {}
        cur = {}
        for s in TAR_set:
            delta[s] = float('inf')
            cur[s] = False
        c_star = []
        while True:
            s_star = max(delta, key=delta.get)
            c_star = user_event_dict[s_star]
            if cur[s_star]:
                A.add(s_star)
                break
            else:
                delta[s_star] = R(c_star)
                cur[s_star] = True
        all_cas += len(c_star)
        for cas_id in list(c_star.keys()):
            uc = event_user_dict[cas_id]
            for v in uc:
                if v in L_dict.keys():
                    L_dict[v] -= 1
                    _ = user_event_dict[v].pop(cas_id)
        print(f"Add a social sensor, sensors num is {len(A)}")
        print(f"Anchor id: {s_star}")
        print(f"all_cas is {all_cas}")
        print(f"TAR_set size: {len(TAR_set)}")
    print(f"Select social sensors finish! Get social sensors, num: {len(A)}")
    return A, all_cas

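# Expected input format (inferred from the parsing in handle_cas and get_t0):
# each line is one cascade,
#   cascade_id,node_a:t1,node_a/node_b:t2,node_a/node_c:t3,...
# i.e. a comma-separated cascade id followed by forwarding paths, where the last
# node of each '/'-separated path is the participating user and the value after
# ':' is its activation timestamp. The node ids and timestamps shown here are
# only illustrative.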
def handle_cas(args):
    filename, obs_time = args.input_file, args.obs_time
    print("Handle cascade begin!")
    user_event_dict = defaultdict(dict)
    event_user_dict = defaultdict(list)
    cascades_total = 0
    with open(filename) as file:
        for line in file:
            cascades_total += 1
            if cascades_total > args.max_cas_num and args.max_cas_num != -1:
                break
            parts = line.split(',')
            cascade_id = parts[0]
            activation_times = {}
            paths = parts[1:]
            t_max = 0
            t_min = float('inf')
            for p in paths:
                # observed adoption/participant
                nodes = p.split(':')[0].split('/')
                time_now = int(p.split(':')[1])
                if time_now > t_max:
                    t_max = time_now
                if time_now < t_min:
                    t_min = time_now
                node = nodes[-1]
                node_id = int(node)
                if time_now > obs_time and obs_time != -1:
                    continue
                if node_id in activation_times.keys():
                    activation_times[node_id] = min(time_now, activation_times[node_id])
                else:
                    activation_times[node_id] = time_now
            for k, v in activation_times.items():
                event_user_dict[cascade_id].append(k)
                if t_max > t_min:
                    user_event_dict[k][cascade_id] = (v - t_min) / (t_max - t_min)
    L_dict = {}
    for k, v in user_event_dict.items():
        L_dict[k] = len(v.keys())
    print(f"Handle cascade file finish! Users num is {len(L_dict)}")
    return L_dict, user_event_dict, event_user_dict

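# generate_anchors interprets --anchor_budget as a fraction of all observed users
# and caps it at 2% of the nodes; the selected anchor ids are written one per
# line to anchors_<budget>.txt under --result_path.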
def generate_anchors(args):
    result_path, anchor_budget = args.result_path, args.anchor_budget
    L_dict, user_event_dict, event_user_dict = handle_cas(args)
    num_nodes = len(L_dict.keys())
    anchor_num = 0
    if anchor_budget > 0.02:
        max_anchor_num = int(num_nodes * 0.02)
        print(f"Max anchor num is {max_anchor_num}, anchor_budget is set to {max_anchor_num}")
        anchor_num = max_anchor_num
    else:
        anchor_num = int(num_nodes * anchor_budget)
    A, all_cas = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, anchor_num)
    os.makedirs(result_path, exist_ok=True)
    output_file = os.path.join(result_path, f'anchors_{args.anchor_budget}.txt')
    with open(output_file, 'w') as file:
        for item in A:
            file.write(f"{item}\n")
    return A, all_cas, user_event_dict

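# For the logistic curve K / (1 + exp(-r * (t - t0))), t0 is the inflection
# point: the fitted cumulative participation reaches K / 2 there and its growth
# rate is maximal, which is why t0 is used below as the cascade's estimated
# outbreak time.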
# Logistic function definition
def logistic(t, K, r, t0):
    return K / (1 + np.exp(-r * (t - t0)))

def find_inflection_logistic(t_list, c_list):
    t_array = np.array(t_list)
    c_array = np.array(c_list)
    # Normalize time to [0, 1]
    t_min = t_array.min()
    t_max = t_array.max()
    if len(set(c_array)) < 3 or max(c_array) <= 1 or t_max - t_min == 0:
        return None
    t_array_normalized = (t_array - t_min) / (t_max - t_min)
    try:
        p0 = [max(c_array) if max(c_array) > 0 else 1.0, 1.0, 0.5]  # keep the initial K positive even if c_array is all zeros
        # Check that c_array is usable
        if max(c_array) <= 0:  # all-zero or negative cumulative counts cannot be fitted with an S-curve
            print("Warning: cumulative counts (c_list) show no positive growth; cannot fit an S-curve.")
            return None
        # Bounds: K > 0, r > 0 (growth rate), t0 within [0, 1]
        bounds = ([1e-6, 1e-6, 0], [np.inf, 10, 1])
        popt, _ = curve_fit(logistic, t_array_normalized, c_array, p0=p0, bounds=bounds, maxfev=10000)
        K, r, t0 = popt
        return float(t0)  # cast to a plain float so the value can be serialized to JSON later
    except (RuntimeError, ValueError):
        print("Warning: curve fitting failed or produced invalid parameters")
        return None

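# get_t0 caches its output: if <result_path>/prepare.json already exists it is
# loaded and returned as-is, so delete that file to force recomputation after
# changing the input data or parameters.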
# Get the outbreak time (t0) of every cascade
def get_t0(args):
    input_file, result_path = args.input_file, args.result_path
    print("Computing t0 for all cascades")
    os.makedirs(result_path, exist_ok=True)
    json_file = os.path.join(result_path, 'prepare.json')
    prepare = {}
    if os.path.exists(json_file):
        print(f"File {json_file} already exists.")
        with open(json_file, 'r', encoding='utf-8') as f:
            prepare = json.load(f)
        return prepare
    t0_dict = {}
    num_time_points = 50  # number of sample points on the average trend curve; adjustable
    common_normalized_timeline = np.linspace(0, 1, num_time_points)
    nodes_in_bins = {i: 0 for i in range(len(common_normalized_timeline) - 1)}
    interpolated_counts_all_cascades = []
    cascades_total = 0
    t0_mean = 0
    with open(input_file) as file:
        for line in file:
            cascades_total += 1
            if cascades_total > args.max_cas_num and args.max_cas_num != -1:
                break
            parts = line.split(',')
            cascade_id = parts[0]
            paths = parts[1:]
            t_max = 0
            t_min = float('inf')
            cas_set = set()
            times = []
            counts = []
            for p in paths:
                nodes = p.split(':')[0].split('/')
                time_now = int(p.split(':')[1])
                if time_now > t_max:
                    t_max = time_now
                if time_now < t_min:
                    t_min = time_now
                node = nodes[-1]
                node_id = int(node)
                cas_set.add(node_id)
                times.append(time_now)
                counts.append(len(cas_set))
            sorted_times = sorted(times)
            if t_max - t_min <= 0:
                continue
            nodes_in_bin = {i: 0 for i in range(len(common_normalized_timeline) - 1)}
            for i in range(len(times)):
                participation_time = (sorted_times[i] - t_min) / (t_max - t_min)
                sorted_times[i] = participation_time
                bin_index = np.digitize(participation_time, common_normalized_timeline, right=False) - 1
                if bin_index == -1 and participation_time == common_normalized_timeline[0]:
                    bin_index = 0
                elif bin_index == len(common_normalized_timeline) - 1:
                    if participation_time == common_normalized_timeline[-1]:
                        bin_index = len(common_normalized_timeline) - 2
                    else:
                        continue
                if 0 <= bin_index < len(nodes_in_bin):
                    nodes_in_bins[bin_index] += 1
                    nodes_in_bin[bin_index] += 1
            interpolated_counts = np.interp(common_normalized_timeline,
                                            sorted_times,
                                            counts)
            estimated_increments_between_common_times = np.diff(interpolated_counts)
            interpolated_counts_all_cascades.append(estimated_increments_between_common_times)
            t0 = find_inflection_logistic(sorted_times, counts)
            if t0 is None:
                continue
            t0_dict[cascade_id] = t0
            t0_mean += t0
    t0_mean /= len(t0_dict)
    stacked_interpolated_counts = np.vstack(interpolated_counts_all_cascades).tolist()
    counts_0 = np.array([nodes_in_bins[i] for i in sorted(nodes_in_bins.keys())])
    all_counts = (counts_0 / counts_0.max()).tolist()
    prepare['t0'] = t0_dict
    prepare['t0_mean'] = t0_mean
    prepare['interpolated_counts_sum'] = stacked_interpolated_counts
    prepare['interpolated_counts_step_avg'] = all_counts
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(prepare, f, indent=4, ensure_ascii=False)
        print(f"t0_dict saved to JSON file: {json_file}")
    except TypeError as e:
        print(f"Error: the dictionary may contain data types that cannot be serialized to JSON. Error: {e}")
    except IOError:
        print(f"Error: cannot open or write file '{json_file}'")
    except Exception as e:
        print(f"An unexpected error occurred while saving the JSON file: {e}")
    return prepare

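# analyse_result reports, for every selected anchor, its average activation time
# and its average activation time relative to each cascade's fitted t0; a
# negative "before_t0" value means the anchor tends to activate before the
# estimated outbreak, which is presumably the behaviour a useful sensor should
# show.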
def analyse_result(node_list, user_event_dict, prepare, args):
    result_path = args.result_path
    t0_dict = prepare['t0']
    t0_mean = prepare['t0_mean']
    anchors_activation_times_before_t0 = defaultdict(list)
    anchors_activation_times = defaultdict(list)
    cas_set = set()
    for k, v in user_event_dict.items():
        if k in node_list:
            for cas_id, ac_time in v.items():
                cas_set.add(cas_id)
                if cas_id not in t0_dict:
                    continue
                anchors_activation_times_before_t0[k].append(ac_time - t0_dict[cas_id])
                anchors_activation_times[k].append(ac_time)
    anchors_avg_activation_time_before_t0 = dict()
    anchors_avg_activation_time = dict()
    for node_id, times in anchors_activation_times_before_t0.items():
        if times:
            anchors_avg_activation_time_before_t0[node_id] = (sum(times) / len(times))
        else:
            anchors_avg_activation_time_before_t0[node_id] = None
    for node_id, times in anchors_activation_times.items():
        if times:
            anchors_avg_activation_time[node_id] = (sum(times) / len(times))
        else:
            anchors_avg_activation_time[node_id] = None
    anchors_total_avg_time_before_t0 = sum(t for t in anchors_avg_activation_time_before_t0.values() if t is not None) / len(anchors_avg_activation_time_before_t0)
    anchors_total_avg_time = sum(t for t in anchors_avg_activation_time.values() if t is not None) / len(anchors_avg_activation_time)
    os.makedirs(result_path, exist_ok=True)
    anylyze_result = {}
    anylyze_result['anchors_act_cas_num'] = len(cas_set)
    anylyze_result['act_cas'] = list(cas_set)
    anylyze_result['anchors_total_avg_time_before_t0'] = anchors_total_avg_time_before_t0
    anylyze_result['anchors_avg_activation_time_before_t0'] = anchors_avg_activation_time_before_t0
    anylyze_result['anchors_total_avg_time'] = anchors_total_avg_time
    anylyze_result['anchors_avg_activation'] = anchors_avg_activation_time
    json_file = os.path.join(result_path, f'anylyze_result_{args.anchor_budget}.json')
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(anylyze_result, f, indent=4, ensure_ascii=False)
        print(f"anylyze_result saved to JSON file: {json_file}")
    except TypeError as e:
        print(f"Error: the dictionary may contain data types that cannot be serialized to JSON. Error: {e}")
    except IOError:
        print(f"Error: cannot open or write file '{json_file}'")
    except Exception as e:
        print(f"An unexpected error occurred while saving the JSON file: {e}")
    print("Start plotting the activation-time figure")
    # --- Plot activation-time distribution ---
    important_times = [t for t in anchors_avg_activation_time_before_t0.values() if t is not None]
    # Continue plotting, using important_times
    plt.figure(figsize=(10, 6))
    sns.kdeplot(important_times, label='anchors', fill=True, color='red', linewidth=2)
    plt.title('Anchors activate time')
    plt.xlabel('Time')
    plt.ylabel('P')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    fig_file_1 = os.path.join(result_path, f'anchor_act_time_{args.anchor_budget}.png')
    plt.savefig(fig_file_1,            # file name
                dpi=300,               # optional: resolution (dots per inch)
                bbox_inches='tight')   # optional: crop whitespace around the figure
    plt.close()
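    # The second figure is a normalized histogram of the anchors' average
    # activation times over [0, 1] (bin width 0.01), with a dashed vertical line
    # at the mean outbreak time t0_mean from prepare.json for comparison.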
print("开始绘制趋势图 ")
# --- 绘制趋势图 ---
min_time = 0.0
max_time = 1 # 假设最大观察时间是 0.5,你可以根据你的数据调整
bin_width = 0.01
num_bins = int(np.ceil((max_time - min_time) / bin_width)) # 向上取整计算bin的数量
# 创建时间段的边界
# 例如:[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
time_bin_edges = np.linspace(min_time, max_time, num_bins + 1)
nodes_in_bins = {i: 0 for i in range(len(time_bin_edges) - 1)}
for node, participation_time in anchors_avg_activation_time.items():
bin_index = np.digitize(participation_time, time_bin_edges, right=False) - 1
if bin_index == -1 and participation_time == time_bin_edges[0]:
bin_index = 0
elif bin_index == len(time_bin_edges) - 1:
if participation_time == time_bin_edges[-1]:
bin_index = len(time_bin_edges) - 2
else:
continue
if 0 <= bin_index < len(nodes_in_bins):
nodes_in_bins[bin_index] += 1
plt.figure(figsize=(10, 6))
bin_labels = []
for i in range(len(time_bin_edges) - 1):
start_time = time_bin_edges[i]
end_time = time_bin_edges[i+1]
# 你可以用区间的字符串表示,或者区间的中心点
bin_labels.append(f"[{start_time:.2f}-{end_time:.2f})")
# bin_labels.append(f"{start_time:.1f}-") # 更简洁的标签,只显示起始
# 获取每个 bin 的计数值
counts_0 = np.array([nodes_in_bins[i] for i in sorted(nodes_in_bins.keys())])
counts = (counts_0 / counts_0.max()).tolist()
# X 轴的位置 (用于条形图)
x_positions = time_bin_edges[:-1] + bin_width/2
bars = plt.bar(x_positions, counts,
width=bin_width, # 条形的宽度
color='skyblue', # 条形的颜色
edgecolor='black') # 条形的边框颜色
# bar_index = 0
# for bar in bars:
# yval = bar.get_height()
# if yval > 0: # 只为非零的条形添加文本
# plt.text(bar.get_x() + bar.get_width()/2.0, yval + 0.05, # 文本位置微调
# counts_0[bar_index],
# ha='center', # 水平居中
# va='bottom') # 垂直对齐方式
# bar_index +=1
# stacked_interpolated_counts = np.array(prepare['interpolated_counts'])
# stacked_log_counts = stacked_interpolated_counts
# # 沿着级联的维度 (axis=0) 计算平均值
# average_counts_trend = np.mean(stacked_log_counts, axis=0) # (num_time_points,)
# max_counts = average_counts_trend.max()
# std_counts_trend = np.std(stacked_log_counts, axis=0)
# sem_counts_trend = std_counts_trend / np.sqrt( average_counts_trend.shape[0])
# # 计算95%置信区间 (使用1.96作为Z分数)
# confidence_interval_upper = average_counts_trend + 1.96 * sem_counts_trend
# confidence_interval_lower = average_counts_trend - 1.96 * sem_counts_trend
# average_counts_trend_normalized = average_counts_trend / max_counts
# confidence_interval_upper_normalized = confidence_interval_upper / max_counts
# confidence_interval_lower_normalized = confidence_interval_lower / max_counts
# confidence_interval_lower_corrected = np.maximum(0, confidence_interval_lower_normalized)
# common_normalized_timeline = np.linspace(0, 1, average_counts_trend.shape[0])
# plt.plot(common_normalized_timeline, average_counts_trend_normalized, label='Average Participation Trend', color='blue', linewidth=2)
# # (可选) 绘制置信区间或标准差区域
# plt.fill_between(common_normalized_timeline,
# confidence_interval_lower_corrected,
# confidence_interval_upper_normalized,
# color='blue', alpha=0.2, label='95% Confidence Interval')
# counts_step = prepare['interpolated_counts_step_avg']
# common_normalized_timeline = np.linspace(0, 1, len(counts_step))
# plt.plot(common_normalized_timeline, counts_step, label='Average Participation Trend', color='blue', linewidth=2)
# --- 使用 plt.axvline() 绘制垂直红线 ---
plt.axvline(x=t0_mean, # 线的 x 坐标
color='red', # 线的颜色
linestyle='--', # 线型 (例如 'solid', '--', '-.', ':')
linewidth=2, # 线的宽度
label=f'Average outbreak time = {t0_mean}')
fig_file_2 = os.path.join(result_path, f'trend_{args.anchor_budget}.png')
plt.savefig(fig_file_2, # 文件名
dpi=300, # 可选:分辨率 (dots per inch)
bbox_inches='tight',# 可选:尝试裁剪掉空白边缘
)
plt.close()
return
def parse_args():
    parser = argparse.ArgumentParser(description='Parameters')
    parser.add_argument('--input_file', default='./dataset_for_anchor.txt', type=str, help='Cascade file')
    parser.add_argument('--result_path', default='./result/', type=str, help='Result save path')
    parser.add_argument('--anchor_budget', default=0.005, type=float, help='Anchor budget as a fraction of all users')
    parser.add_argument('--max_cas_num', default=-1, type=int, help='Maximum number of cascades to read (-1 means all)')
    parser.add_argument('--obs_time', default=-1, type=int, help='Anchor observation time; the default -1 means all activations are observed')
    parser.add_argument('--key_node_json', default='./key_node.json', type=str, help='key_node_json save file')
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
    # =========== Run the anchor-mining algorithm; returns the anchors and the number of sensed cascades ===========
    A, all_cas, user_event_dict = generate_anchors(args)
    # =========== Analyse the results ===========
    print("Start analysing results")
    prepare = get_t0(args)
    analyse_result(A, user_event_dict, prepare, args)
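# Example invocation (assuming the default file layout; adjust paths as needed):
#   python mdbs/anchor.py --input_file ./dataset_for_anchor.txt \
#       --result_path ./result/ --anchor_budget 0.005 --max_cas_num -1 --obs_time -1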