renamed: dataset_for_anchor.txt -> data/dataset_for_anchor.txt modified: main.py modified: model/anchor_finder.py
74 lines
3.0 KiB
Python
74 lines
3.0 KiB
Python
import argparse
|
||
from collections import defaultdict
|
||
import json
|
||
import os
|
||
from model import anchor_finder
|
||
|
||
def handle_cas(args):
    """Parse a cascade file into a per-cascade event dictionary and save it as JSON.

    Each non-blank line of ``args.input_file`` is read as::

        cascade_id,path1:time1,path2:time2,...

    Every ``path`` is split on ``/`` and its last element is taken as the
    participating user; the text after the first ``:`` (up to any further
    ``:``) is taken as the activation time. Entries without a ``:`` are skipped.

    Args:
        args: Namespace providing ``input_file`` (cascade file path),
            ``result_path`` (output directory, created if missing) and
            ``obs_cas_num`` (maximum number of cascades to read; a value
            ``<= 0`` means read all of them).

    Returns:
        ``defaultdict(list)`` mapping ``cascade_id`` (str) to a list of
        ``{'user_id': str, 'act_time': str}`` dicts in file order. A cascade
        with no valid entries gets no key. The same mapping is written to
        ``<result_path>/event_dict.json``; a write failure is reported on
        stdout but the mapping is still returned.
    """
    filename, result_path = args.input_file, args.result_path
    print("Handle cascade begin!")

    event_dict = defaultdict(list)
    cascades_total = 0
    with open(filename, encoding='utf-8') as file:
        for line in file:
            # Strip the line terminator; otherwise the last act_time of every
            # cascade would carry a "\n" into the result and the JSON file.
            line = line.strip()
            if not line:
                continue  # blank lines are not cascades and do not use up the budget
            cascades_total += 1
            # obs_cas_num <= 0 disables the limit.
            if 0 < args.obs_cas_num < cascades_total:
                break
            cascade_id, events = _parse_cascade_line(line)
            if events:
                # Only touch the defaultdict when there is something to add,
                # so cascades without valid entries do not create empty keys.
                event_dict[cascade_id].extend(events)
            print(f"cascade_id:{cascade_id}")

    print("Handle cascade file finish!")
    _save_event_dict(event_dict, result_path)
    return event_dict


def _parse_cascade_line(line):
    """Split one stripped cascade line into ``(cascade_id, [event dicts])``."""
    cascade_id, *paths = line.split(',')
    events = []
    for path in paths:
        # Observed adoption/participant entry: "<n1>/<n2>/.../<user>:<time>".
        node_and_time = path.split(':')
        if len(node_and_time) < 2:
            continue  # malformed entry without a time part
        user_id = node_and_time[0].split('/')[-1]
        events.append({
            'user_id': user_id,
            'act_time': node_and_time[1],
        })
    return cascade_id, events


def _save_event_dict(event_dict, result_path):
    """Best-effort write of ``event_dict`` to ``<result_path>/event_dict.json``."""
    os.makedirs(result_path, exist_ok=True)
    json_file = os.path.join(result_path, 'event_dict.json')
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(event_dict, f, indent=4, ensure_ascii=False)
        # Message: "user_event_dict saved successfully to JSON file".
        print(f"user_event_dict已成功保存到 JSON 文件: {json_file}")
    except TypeError as e:
        # Message: "dict may contain values not serializable to JSON".
        print(f"错误: 字典中可能包含无法序列化为 JSON 的数据类型。错误: {e}")
    except OSError:
        # Message: "cannot open or write file".
        print(f"错误: 无法打开或写入文件 '{json_file}'。")
    except Exception as e:
        # Deliberate best-effort: report, but still let the caller use the dict.
        # Message: "unknown error while saving JSON file".
        print(f"保存 JSON 文件时发生未知错误: {e}")
|
||
|
||
def parse_args(argv=None):
    """Build the command-line parser and parse the anchor-finding options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as the original no-argument call did.

    Returns:
        argparse.Namespace with ``input_file``, ``result_path``,
        ``anchor_budget``, ``anchor_num``, ``obs_cas_num``, ``obs_time`` and
        ``anchor_screen``.
    """
    parser = argparse.ArgumentParser(description='Parameters')
    parser.add_argument('--input_file', default='./data/dataset_for_anchor.txt', type=str, help='Cascade file')
    parser.add_argument('--result_path', default='./result/', type=str, help='result save path')
    parser.add_argument('--anchor_budget', default=0.005, type=float, help='Anchors proportion')
    parser.add_argument('--anchor_num', default=10, type=int, help='Anchors num')
    # handle_cas treats a value <= 0 as "read every cascade in the file".
    parser.add_argument('--obs_cas_num', default=200, type=int, help='Anchors observe cas num')
    parser.add_argument('--obs_time', default=-1, type=int,
                        help='Anchors observe time, default setting is -1, meaning can observe all')
    # argparse only applies `type` to string defaults, so a literal 1 would
    # stay an int; 1.0 keeps the default the same type as user-supplied values.
    parser.add_argument('--anchor_screen', default=1.0, type=float, help='anchor_screen')
    return parser.parse_args(argv)
|
||
|
||
if __name__ == '__main__':
    options = parse_args()

    # Demo only: event_dict is produced here from the sample cascade file by
    # handle_cas(); in production you must build event_dict yourself.
    event_dict = handle_cas(options)

    # Run the anchor ("m-point") mining algorithm. The original note says it
    # returns the anchors and the number of cascades they perceive; the call
    # unpacks three values. TODO confirm their exact meaning in
    # model/anchor_finder.py.
    A, all_cas, user_event_dict = anchor_finder.generate_anchors(
        event_dict,
        options.anchor_budget,
        options.anchor_screen,
        options.obs_time,
        options.result_path,
        options.anchor_num,
    )