mdbs/anchor.py
2025-05-29 09:49:12 +08:00

224 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
社交网络锚点选择系统
主要功能:
1. 从社交网络传播数据中处理级联事件(cascades)
2. 基于用户参与事件的数量和奖励函数R选择最优的社交传感器(anchors)
3. 输出选中的锚点到文件
输入:级联事件文件(每行代表一个信息传播路径)
输出选中的锚点ID列表
"""
from collections import defaultdict
import numpy as np
import copy
import argparse
def find_max_indices_numpy(L_dict):
"""使用numpy高效找到字典中值最大的键
参数:
L_dict: {用户ID: 参与事件数}的字典
返回:
参与事件数最多的用户ID数组
"""
keys_arr = np.array(list(L_dict.keys()))
values_arr = np.array(list(L_dict.values()))
max_val = max(L_dict.values()) # 找到最大值
indices = np.where(values_arr == max_val)[0] # 找到所有等于最大值的索引
return keys_arr[indices] # 返回对应的键
def R(event_dict):
"""奖励函数,计算事件字典的负值和
参数:
event_dict: {事件ID: 激活时间}的字典
返回:
负的事件数总和(用于最小化监测的事件数)
"""
return -sum(event_dict.values())
def select_social_sensors(R, L_dict, user_event_dict, event_user_dict, b):
"""
贪心算法选择社交传感器
参数:
R: 奖励函数
L_dict: {用户ID: 参与事件数}
user_event_dict: {用户ID: {事件ID: 激活时间}}
event_user_dict: {事件ID: [参与用户列表]}
b: 预算(最大锚点数)
返回:
选中的锚点集合
"""
print("开始选择社交传感器...")
user_event_dict = copy.deepcopy(user_event_dict) # 避免修改原始数据
V = list(L_dict.keys()) # 所有用户
A = set() # 已选锚点
f = lambda A: len(A) # 当前锚点数
all_cas = 0 # 总共覆盖的事件数
# 主循环:直到选够锚点或所有用户都被考虑
while any(s not in A for s in V) and f(A) < b:
# 1. 找到当前参与事件最多的用户
indices_np = find_max_indices_numpy(L_dict)
TAR_set = set(indices_np)
# 2. 评估候选用户的增量收益
delta = {s: float('inf') for s in TAR_set} # 收益字典
cur = {s: False for s in TAR_set} # 是否已评估
c_star = []
while True:
# 选择当前收益最大的用户
s_star = max(delta, key=delta.get)
c_star = user_event_dict[s_star]
if cur[s_star]: # 如果已评估过,选择该用户
A.add(s_star)
break
else: # 否则计算其收益
delta[s_star] = R(c_star)
cur[s_star] = True
# 3. 更新统计信息
all_cas += len(c_star)
# 4. 更新数据结构:移除已被覆盖的事件
for cas_id in list(c_star.keys()):
# 从事件参与用户中移除该事件
for v in event_user_dict[cas_id]:
if v in L_dict:
L_dict[v] -= 1 # 减少用户参与事件计数
user_event_dict[v].pop(cas_id, None) # 移除事件
print(f"添加社交传感器,当前数量: {len(A)}")
print(f"锚点ID: {s_star}")
print(f"总覆盖事件数: {all_cas}")
print(f"候选集大小: {len(TAR_set)}")
print(f"选择完成! 共选择 {len(A)} 个社交传感器")
return A
def handle_cas(filename, obs_time=-1):
"""
处理级联事件文件,构建数据结构
参数:
filename: 级联事件文件路径
obs_time: 观察时间阈值(-1表示无限制)
返回:
L_dict: {用户ID: 参与事件数}
user_event_dict: {用户ID: {事件ID: 标准化激活时间}}
event_user_dict: {事件ID: [参与用户列表]}
"""
print("开始处理级联事件...")
user_event_dict = defaultdict(dict)
event_user_dict = defaultdict(list)
cascades_total = 0
with open(filename) as file:
for line in file:
cascades_total += 1
parts = line.split(',')
cascade_id = parts[0] # 事件ID
activation_times = {} # 用户激活时间
paths = parts[1:] # 传播路径
# 计算事件的最早和最晚时间
t_max, t_min = 0, float('inf')
for p in paths:
nodes = p.split(':')[0].split('/')
time_now = int(p.split(':')[1])
t_max = max(t_max, time_now)
t_min = min(t_min, time_now)
node_id = int(nodes[-1]) # 最后一个节点是激活用户
# 如果超过观察时间则跳过
if obs_time != -1 and time_now > obs_time:
continue
# 记录用户的最早激活时间
if node_id in activation_times:
activation_times[node_id] = min(time_now, activation_times[node_id])
else:
activation_times[node_id] = time_now
# 构建数据结构
for user_id, act_time in activation_times.items():
event_user_dict[cascade_id].append(user_id)
# 计算标准化激活时间(0-1之间)
if t_max > t_min:
norm_time = (act_time - t_min) / (t_max - t_min)
user_event_dict[user_id][cascade_id] = norm_time
# 计算每个用户参与的事件数
L_dict = {user: len(events) for user, events in user_event_dict.items()}
print(f"处理完成! 共 {len(L_dict)} 个用户")
return L_dict, user_event_dict, event_user_dict
def generate_anchors(input_file, output_file, anchor_budget, obs_time=-1):
"""
生成锚点的主函数
参数:
input_file: 输入文件路径
output_file: 输出文件路径
anchor_budget: 锚点预算数量
obs_time: 观察时间阈值
"""
# 1. 处理数据
L_dict, user_event_dict, event_user_dict = handle_cas(input_file, obs_time)
# 2. 计算最大锚点数(不超过总用户数的2%)
num_nodes = len(L_dict)
max_anchor_num = int(num_nodes * 0.02)
if anchor_budget > max_anchor_num:
print(f"最大锚点数 {max_anchor_num},预算调整为 {max_anchor_num}")
anchor_budget = max_anchor_num
# 3. 选择锚点
A = select_social_sensors(R, L_dict, user_event_dict, event_user_dict, anchor_budget)
# 4. 保存结果
with open(output_file, 'w') as file:
for anchor in A:
file.write(f"{anchor}\n")
def parse_args():
"""解析命令行参数"""
parser = argparse.ArgumentParser(description='社交网络锚点选择参数')
parser.add_argument('--input_file', default='./dataset_for_anchor.txt',
type=str, help='级联事件文件路径')
parser.add_argument('--output_file', default='./anchors.txt',
type=str, help='锚点输出文件路径')
parser.add_argument('--anchor_budget', default=100,
type=int, help='锚点预算数量')
parser.add_argument('--obs_time', default=-1,
type=int, help='观察时间阈值(-1表示无限制)')
return parser.parse_args()
if __name__ == '__main__':
args = parse_args()
generate_anchors(args.input_file, args.output_file,
args.anchor_budget, args.obs_time)