mdbs/main.py
bill_juju 1f72b0e6c9 new file: main.py
new file:   model/anchor_finder.py
2025-08-07 11:26:36 +08:00

78 lines
3.4 KiB
Python

import argparse
from collections import defaultdict
import json
import os
from model import anchor_finder
def _parse_cascade_line(line):
    """Split one cascade record into its id and its list of adoption events.

    A record looks like ``cascade_id,chain:time,chain:time,...`` where each
    ``chain`` is a ``/``-separated sequence of node ids. Only the last node of
    each chain (the observed adoption/participant) is kept. Entries without a
    ``:time`` part are skipped.

    Args:
        line: one record, already stripped of its line terminator.

    Returns:
        ``(cascade_id, events)`` where ``events`` is a list of
        ``{'user_id': str, 'act_time': str}`` dicts in file order. Times are
        kept as strings, exactly as they appear in the file.
    """
    cascade_id, *paths = line.split(',')
    events = []
    for p in paths:
        n_t = p.split(':')
        if len(n_t) < 2:
            # No ':time' part -> not an observed adoption; ignore it.
            continue
        events.append({
            'user_id': n_t[0].split('/')[-1],
            'act_time': n_t[1],
        })
    return cascade_id, events


def _save_event_dict(event_dict, result_path):
    """Write *event_dict* to ``<result_path>/event_dict.json`` (best effort).

    Save failures are reported on stdout instead of raised, so a failed
    save does not abort the caller. Creating the directory can still raise.
    """
    os.makedirs(result_path, exist_ok=True)
    json_file = os.path.join(result_path, 'event_dict.json')
    try:
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(event_dict, f, indent=4, ensure_ascii=False)
        print(f"user_event_dict已成功保存到 JSON 文件: {json_file}")
    except TypeError as e:
        # Non-JSON-serialisable value inside the dict.
        print(f"错误: 字典中可能包含无法序列化为 JSON 的数据类型。错误: {e}")
    except OSError:
        print(f"错误: 无法打开或写入文件 '{json_file}'")
    except Exception as e:
        print(f"保存 JSON 文件时发生未知错误: {e}")


def handle_cas(args):
    """Parse the cascade file into per-cascade adoption events and save them as JSON.

    Args:
        args: namespace providing ``input_file`` (text file, one cascade
            record per line) and ``result_path`` (directory where
            ``event_dict.json`` is written; created if missing).

    Returns:
        ``defaultdict(list)`` mapping cascade_id -> list of
        ``{'user_id': str, 'act_time': str}`` dicts in file order. Cascades
        with no valid ``chain:time`` entry get no key.
    """
    filename, result_path = args.input_file, args.result_path
    print("Handle cascade begin!")
    # Values are lists: each cascade accumulates its events in order.
    event_dict = defaultdict(list)
    cascades_total = 0
    with open(filename, encoding='utf-8') as file:
        for line in file:
            # Strip the line terminator so it does not leak into the last act_time.
            line = line.strip()
            if not line:
                continue
            cascades_total += 1
            cascade_id, events = _parse_cascade_line(line)
            if events:
                # Only touch the defaultdict when there is something to add,
                # so cascades without valid events do not create empty keys.
                event_dict[cascade_id].extend(events)
            print(f"cascade_id:{cascade_id}")
    print("Handle cascade file finish!")
    print(f"Total cascades: {cascades_total}")
    _save_event_dict(event_dict, result_path)
    return event_dict
def parse_args(argv=None):
    """Build the command-line parser and parse the options.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser(description='Parameters')
    parser.add_argument('--input_file', default='./dataset_for_anchor.txt', type=str, help='Cascade file')
    parser.add_argument('--result_path', default='./result/', type=str, help='result save path')
    parser.add_argument('--anchor_budget', default=0.005, type=float, help='Anchors proportion')
    parser.add_argument('--anchor_num', default=1000, type=int, help='Anchors num')
    parser.add_argument('--obs_cas_num', default=-1, type=int, help='Anchors observe cas num')
    parser.add_argument('--obs_time', default=-1, type=int,
                        help='Anchors observe time, default setting is -1, meaning can observe all')
    # NOTE(review): machine-specific absolute default; override with --key_node_file on other hosts.
    parser.add_argument('--key_node_file', default='/home/zjy/project/CasDO/dataset/nh/key_node_10', type=str,
                        help='key_node_json save file')
    parser.add_argument('--anchor_screen', default=-0.8, type=float, help='anchor_screen')
    parser.add_argument('--cas_min_length', default=20, type=int, help='Cas min length')
    # NOTE(review): t_max and t_min share the same default value; confirm this is intended.
    parser.add_argument('--t_max', default=1751299200, type=int, help='Max time (t_max)')
    parser.add_argument('--t_min', default=1751299200, type=int, help='Min time (t_min)')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: parse CLI options, build the per-cascade event dict,
    # then run anchor mining on it.
    args = parse_args()
    input_file, result_path, anchor_budget, obs_time, anchor_num = args.input_file, args.result_path, args.anchor_budget, args.obs_time, args.anchor_num
    # handle_cas reads args.input_file and writes event_dict.json under args.result_path.
    event_dict = handle_cas(args)
    print(f"anchor_num is set to {anchor_num}")
    # =========== Run the m-point (anchor) mining algorithm; returns the m-points and the number of perceived cascades ===========
    # NOTE(review): the result is unpacked into three values (A, all_cas, user_event_dict);
    # their exact meaning is defined in model.anchor_finder — confirm there.
    # NOTE(review): several parsed options (obs_cas_num, key_node_file, anchor_screen,
    # cas_min_length, t_max, t_min) are not passed in the visible lines.
    A, all_cas, user_event_dict = anchor_finder.generate_anchors(event_dict, anchor_budget, obs_time, result_path, anchor_num)