# File viewer metadata: Python source, 78 lines, 3.4 KiB
|
|
import argparse
|
||
|
|
from collections import defaultdict
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
from model import anchor_finder
|
||
|
|
|
||
|
|
def handle_cas(args):
    """Parse a cascade file into per-cascade event lists and save them as JSON.

    Every non-empty line of the input file is read as
    ``cascade_id,path:time,path:time,...``. Each ``path`` is a
    ``/``-separated chain of node ids; the last node of the chain is recorded
    as the acting user together with the ``time`` field that follows the
    first ``:``. Entries without a ``:`` are skipped.

    The collected mapping is written to ``<result_path>/event_dict.json``.
    A failure while writing is reported on stdout and does not abort the
    function, so the in-memory result is still returned.

    Args:
        args: Object exposing ``input_file`` (path of the cascade file) and
            ``result_path`` (directory that receives ``event_dict.json``;
            created if missing).

    Returns:
        A ``defaultdict(list)`` mapping each cascade id to a list of
        ``{'user_id': str, 'act_time': str}`` dicts in file order.
    """
    filename, result_path = args.input_file, args.result_path
    print("Handle cascade begin!")

    # Must be a list factory: events are appended per cascade below.
    event_dict = defaultdict(list)
    cascades_total = 0

    with open(filename, encoding='utf-8') as file:
        for raw_line in file:
            # Strip the line terminator so it does not leak into the last
            # act_time, and ignore blank lines instead of creating an '' key.
            line = raw_line.strip()
            if not line:
                continue
            cascades_total += 1

            cascade_id, *paths = line.split(',')

            for p in paths:
                # observed adoption/participant
                n_t = p.split(':')
                if len(n_t) < 2:
                    continue
                # The acting user is the last node of the propagation path.
                node_id = n_t[0].split('/')[-1]
                event_dict[cascade_id].append({
                    'user_id': node_id,
                    'act_time': n_t[1],
                })
            print(f"cascade_id:{cascade_id}")

    print("Handle cascade file finish!")
    print(f"cascades_total:{cascades_total}")

    os.makedirs(result_path, exist_ok=True)
    json_file = os.path.join(result_path, 'event_dict.json')
    # Saving is best-effort: problems are reported but the parsed data is
    # still handed back to the caller.
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(event_dict, f, indent=4, ensure_ascii=False)
        print(f"user_event_dict已成功保存到 JSON 文件: {json_file}")
    except TypeError as e:
        print(f"错误: 字典中可能包含无法序列化为 JSON 的数据类型。错误: {e}")
    except IOError:
        print(f"错误: 无法打开或写入文件 '{json_file}'。")
    except Exception as e:
        print(f"保存 JSON 文件时发生未知错误: {e}")

    return event_dict
|
||
|
|
|
||
|
|
def parse_args(argv=None):
    """Build the command-line parser and parse the given arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            ``argparse`` falls back to ``sys.argv[1:]``, which is what a call
            without arguments has always done.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.
    """
    parser = argparse.ArgumentParser(description='Parameters')
    parser.add_argument('--input_file', default='./dataset_for_anchor.txt', type=str, help='Cascade file')
    parser.add_argument('--result_path', default='./result/', type=str, help='result save path')
    parser.add_argument('--anchor_budget', default=0.005, type=float, help='Anchors proportion')
    parser.add_argument('--anchor_num', default=1000, type=int, help='Anchors num')
    parser.add_argument('--obs_cas_num', default=-1, type=int, help='Anchors observe cas num')
    parser.add_argument('--obs_time', default=-1, type=int, help='Anchors observe time, default setting is -1, meaning can observe all')
    # NOTE(review): the default is a machine-specific absolute path; kept
    # unchanged so existing invocations behave the same.
    parser.add_argument('--key_node_file', default='/home/zjy/project/CasDO/dataset/nh/key_node_10', type=str, help='key_node_json save file')
    parser.add_argument('--anchor_screen', default=-0.8, type=float, help='anchor_screen')
    parser.add_argument('--cas_min_length', default=20, type=int, help='Cas min length')
    # NOTE(review): t_max and t_min share the same default, which looks like
    # a Unix epoch timestamp in seconds -- confirm the intended values.
    parser.add_argument('--t_max', default=1751299200, type=int, help='Maximum timestamp (t_max)')
    parser.add_argument('--t_min', default=1751299200, type=int, help='Minimum timestamp (t_min)')
    return parser.parse_args(argv)
|
||
|
|
|
||
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI options, build the per-cascade event
    # dictionary, then hand it to the anchor finder.
    cli = parse_args()

    cascades = handle_cas(cli)
    print(f"anchor_num is set to {cli.anchor_num}")

    # Run the m-point (anchor) mining algorithm. Per the original note it
    # returns the m-points together with the number of cascades they sense.
    anchors, observed_cascades, user_events = anchor_finder.generate_anchors(
        cascades,
        cli.anchor_budget,
        cli.obs_time,
        cli.result_path,
        cli.anchor_num,
    )
|