"""
|
||
社交网络锚点选择系统
|
||
|
||
主要功能:
|
||
1. 从社交网络传播数据中处理级联事件(cascades)
|
||
2. 基于用户参与事件的数量和奖励函数R,选择最优的社交传感器(anchors)
|
||
3. 输出选中的锚点到文件
|
||
|
||
输入:级联事件文件(每行代表一个信息传播路径)
|
||
输出:选中的锚点ID列表
|
||
"""
from collections import defaultdict
import numpy as np
import copy
import argparse

def find_max_indices_numpy(L_dict):
    """Use numpy to find the keys holding the largest value in a dict.

    Args:
        L_dict: {user ID: number of events the user participates in}

    Returns:
        Array of the user IDs with the highest participation count.
    """
    keys_arr = np.array(list(L_dict.keys()))
    values_arr = np.array(list(L_dict.values()))
    max_val = values_arr.max()  # largest participation count

    indices = np.where(values_arr == max_val)[0]  # positions equal to the max
    return keys_arr[indices]  # the corresponding keys
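
# For example (hypothetical data):
#   find_max_indices_numpy({1: 3, 2: 5, 3: 5})  ->  array([2, 3])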


def R(event_dict):
    """Reward function: the negative sum of the dict's values.

    Args:
        event_dict: {event ID: normalized activation time}

    Returns:
        Negative sum of the normalized activation times; maximizing R
        therefore favors users who are activated early in their cascades.
    """
    return -sum(event_dict.values())
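
# For example (hypothetical data): R({'c1': 0.25, 'c2': 0.5}) == -0.75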


def select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b):
    """
    Greedily select social sensors.

    Args:
        R: reward function
        L_dict: {user ID: number of events the user participates in}
        user_event_dict: {user ID: {event ID: normalized activation time}}
        event_user_dict: {event ID: [participating user IDs]}
        b: budget (maximum number of anchors)

    Returns:
        The set of selected anchors.
    """
    print("Selecting social sensors...")
    user_event_dict = copy.deepcopy(user_event_dict)  # do not mutate the caller's data
    L_dict = dict(L_dict)  # local copy; selected users are removed from it below

    A = set()    # selected anchors
    all_cas = 0  # total number of events covered so far

    # Main loop: stop once the budget is reached or no candidates remain.
    while L_dict and len(A) < b:
        # 1. Find the users with the highest remaining participation count.
        indices_np = find_max_indices_numpy(L_dict)
        TAR_set = set(indices_np)

        # 2. Evaluate the marginal gain of each candidate.
        delta = {s: float('inf') for s in TAR_set}  # gain per candidate
        cur = {s: False for s in TAR_set}           # whether the gain is up to date

        c_star = {}
        while True:
            # Pick the candidate with the largest (possibly stale) gain.
            s_star = max(delta, key=delta.get)
            c_star = user_event_dict[s_star]

            if cur[s_star]:  # gain is up to date: select this user
                A.add(s_star)
                # Remove the selected user so it cannot be picked again
                # (otherwise the outer loop can spin forever on it).
                L_dict.pop(s_star, None)
                break
            else:  # otherwise (re)compute its gain
                delta[s_star] = R(c_star)
                cur[s_star] = True

        # 3. Update statistics.
        all_cas += len(c_star)

        # 4. Remove the events that the new sensor already covers.
        for cas_id in list(c_star.keys()):
            for v in event_user_dict[cas_id]:
                if v in L_dict:
                    L_dict[v] -= 1                        # one fewer uncovered event
                    user_event_dict[v].pop(cas_id, None)  # drop the covered event

        print(f"Added a social sensor, current count: {len(A)}")
        print(f"Anchor ID: {s_star}")
        print(f"Total events covered: {all_cas}")
        print(f"Candidate set size: {len(TAR_set)}")

    print(f"Done! Selected {len(A)} social sensors in total")
    return A
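
# Toy trace (hypothetical data): with budget b=1 and
#   L_dict          = {1: 2, 2: 1}
#   user_event_dict = {1: {'c1': 0.0, 'c2': 1.0}, 2: {'c1': 1.0}}
#   event_user_dict = {'c1': [1, 2], 'c2': [1]}
# user 1 has the highest participation count and is the only candidate,
# so the function returns A == {1} after covering both events.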


def handle_cas(filename, obs_time=-1):
    """
    Parse the cascade event file and build the data structures.

    Args:
        filename: path to the cascade event file
        obs_time: observation-time threshold (-1 means no limit)

    Returns:
        L_dict: {user ID: number of events the user participates in}
        user_event_dict: {user ID: {event ID: normalized activation time}}
        event_user_dict: {event ID: [participating user IDs]}
    """
    print("Parsing cascade events...")
    user_event_dict = defaultdict(dict)
    event_user_dict = defaultdict(list)
    cascades_total = 0

    with open(filename) as file:
        for line in file:
            cascades_total += 1
            parts = line.strip().split(',')
            cascade_id = parts[0]  # event ID
            activation_times = {}  # earliest activation time per user

            paths = parts[1:]  # propagation paths

            # Earliest and latest activation times in this cascade.
            t_max, t_min = 0, float('inf')
            for p in paths:
                nodes = p.split(':')[0].split('/')
                time_now = int(p.split(':')[1])

                t_max = max(t_max, time_now)
                t_min = min(t_min, time_now)

                node_id = int(nodes[-1])  # the last node is the activated user

                # Skip activations outside the observation window.
                if obs_time != -1 and time_now > obs_time:
                    continue

                # Keep each user's earliest activation time.
                if node_id in activation_times:
                    activation_times[node_id] = min(time_now, activation_times[node_id])
                else:
                    activation_times[node_id] = time_now

            # Build the data structures.
            for user_id, act_time in activation_times.items():
                event_user_dict[cascade_id].append(user_id)

                # Normalize the activation time into [0, 1].
                if t_max > t_min:
                    norm_time = (act_time - t_min) / (t_max - t_min)
                    user_event_dict[user_id][cascade_id] = norm_time

    # Number of events each user participates in.
    L_dict = {user: len(events) for user, events in user_event_dict.items()}

    print(f"Done! {len(L_dict)} users in total")
    return L_dict, user_event_dict, event_user_dict
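
# Worked example (hypothetical line): "c1,1/2:0,1/3:10" yields
#   activation_times = {2: 0, 3: 10}   (the last node of each path)
#   t_min, t_max     = 0, 10
#   event_user_dict  = {'c1': [2, 3]}
#   user_event_dict  = {2: {'c1': 0.0}, 3: {'c1': 1.0}}
#   L_dict           = {2: 1, 3: 1}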


def generate_anchors(input_file, output_file, anchor_budget, obs_time=-1):
    """
    Main entry point for anchor generation.

    Args:
        input_file: path to the input file
        output_file: path to the output file
        anchor_budget: anchor budget (number of anchors to select)
        obs_time: observation-time threshold
    """
    # 1. Parse the data.
    L_dict, user_event_dict, event_user_dict = handle_cas(input_file, obs_time)

    # 2. Cap the number of anchors at 2% of the total number of users.
    num_nodes = len(L_dict)
    max_anchor_num = int(num_nodes * 0.02)

    if anchor_budget > max_anchor_num:
        print(f"Maximum number of anchors is {max_anchor_num}; "
              f"budget reduced to {max_anchor_num}")
        anchor_budget = max_anchor_num

    # 3. Select the anchors.
    A = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, anchor_budget)

    # 4. Save the result.
    with open(output_file, 'w') as file:
        for anchor in A:
            file.write(f"{anchor}\n")
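
# For example, with 10,000 users the budget is capped at
# int(10000 * 0.02) == 200 anchors.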


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description='Anchor selection parameters')
    parser.add_argument('--input_file', default='./dataset_for_anchor.txt',
                        type=str, help='path to the cascade event file')
    parser.add_argument('--output_file', default='./anchors.txt',
                        type=str, help='path to the anchor output file')
    parser.add_argument('--anchor_budget', default=100,
                        type=int, help='anchor budget (number of anchors)')
    parser.add_argument('--obs_time', default=-1,
                        type=int, help='observation-time threshold (-1 = no limit)')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    generate_anchors(args.input_file, args.output_file,
                     args.anchor_budget, args.obs_time)
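
# Example invocation (the script name is illustrative):
#   python select_anchors.py --input_file ./dataset_for_anchor.txt \
#       --output_file ./anchors.txt --anchor_budget 100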