diff --git a/.gitignore b/.gitignore index dbd8938..fcb20e1 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ !/.vs/settings.json /logging/ Cargo.lock -Conf.toml \ No newline at end of file +Conf.toml +/data \ No newline at end of file diff --git a/src/commands/general.rs b/src/commands/general.rs index 52d3b3c..482bc9f 100644 --- a/src/commands/general.rs +++ b/src/commands/general.rs @@ -10,7 +10,7 @@ use serenity::{ }; #[group] -#[commands(longcode, image, older, ping, invite, infos, error)] +#[commands(longcode, image, older, ping, invite, infos, error, send_message)] pub struct General; #[command] @@ -288,3 +288,13 @@ impl std::str::FromStr for Image { } } } + +#[command] +#[owners_only] +async fn send_message(ctx: &Context, _msg: &Message, mut args: Args) -> CommandResult { + let channel_id = args.single::()?; + let message = args.single::()?; + debugln!("Send {} into {:?}", message, channel_id); + channel_id.say(ctx, message).await?; + Ok(()) +} diff --git a/src/commands/roulette.rs b/src/commands/roulette.rs index 48b3c3a..d06cdde 100644 --- a/src/commands/roulette.rs +++ b/src/commands/roulette.rs @@ -8,18 +8,24 @@ use serenity::{ model::prelude::*, prelude::*, }; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; -pub struct BulletsContainer; +pub(crate) struct BulletsContainer; impl TypeMapKey for BulletsContainer { type Value = HashMap; } +pub(crate) struct NonKickGuildsContainer; + +impl TypeMapKey for NonKickGuildsContainer { + type Value = HashSet; +} + #[group] #[default_command(shot)] #[prefix("roulette")] -#[commands(reload, shot, check, kick)] +#[commands(reload, shot, check, kick, disable_kick)] struct Roulette; #[command] @@ -115,7 +121,6 @@ fn bullet_to_str<'m>(nbr: u8) -> &'m str { } } -// TODO ALLOW ADMINS TO DISABLE THAT #[command] #[description = "DO IT"] #[only_in(guilds)] @@ -133,32 +138,72 @@ async fn kick(ctx: &Context, msg: &Message) -> CommandResult { } async fn _kick(ctx: &Context, msg: &Message) -> commands::Result<()> { - let mut data = ctx.data.write().await; - let bullets_map = data - .get_mut::() - .expect("Expected CommandCounter in TypeMap."); - let bullets = bullets_map - .entry(msg.author.id.0) - .or_insert((5, rand::thread_rng().gen_range(0, 6))); - if bullets.0 == bullets.1 { - api::send_reply(ctx, &msg, "💥").await?; - *bullets = (5, rand::thread_rng().gen_range(0, 6)); - if let Some(guild_id) = &msg.guild_id { - guild_id - .member(&ctx.http, &msg.author) - .await? - .kick_with_reason(&ctx.http, "You loose at the roulette") + if let Some(guild_id) = &msg.guild_id { + let mut data = ctx.data.write().await; + let non_kick_guilds = data + .get_mut::() + .expect("Expected NonKickGuildsContainer in TypeMap."); + if non_kick_guilds.contains(guild_id.as_u64()) { + msg.channel_id + .say( + ctx, + "Error : You cannot play to the REAL RUSSIAN ROULETTE in this guild", + ) .await?; + } else { + let bullets_map = data + .get_mut::() + .expect("Expected CommandCounter in TypeMap."); + let bullets = bullets_map + .entry(msg.author.id.0) + .or_insert((5, rand::thread_rng().gen_range(0, 6))); + if bullets.0 == bullets.1 { + api::send_reply(ctx, &msg, "💥").await?; + *bullets = (5, rand::thread_rng().gen_range(0, 6)); + + guild_id + .member(&ctx.http, &msg.author) + .await? + .kick_with_reason(&ctx.http, "You loose at the roulette") + .await?; + } else { + *bullets = (bullets.0 - 1, bullets.1); + api::send_reply( + &ctx, + &msg, + format!("Click ! bullets remaining : {}", bullets.0 + 1), + ) + .await?; + } + debugln!("Bullets Map : {:?}", bullets_map); + } + } + Ok(()) +} + +#[command] +#[description = "Disable kicking"] +#[only_in(guilds)] +#[required_permissions("ADMINISTRATOR")] +async fn disable_kick(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { + let disable = match args.len() { + 0 => true, + _ => args.single::()?, + }; + let mut data = ctx.data.write().await; + let non_kick_guilds = data + .get_mut::() + .expect("Expected NonKickGuildsContainer in TypeMap."); + + if let Some(guild_id) = msg.guild_id { + let id = *guild_id.as_u64(); + if disable { + non_kick_guilds.insert(id); + msg.channel_id.say(ctx, "No fun allowed").await?; + } else { + non_kick_guilds.remove(&id); + msg.channel_id.say(ctx, "Done").await?; } - } else { - *bullets = (bullets.0 - 1, bullets.1); - api::send_reply( - &ctx, - &msg, - format!("Click ! bullets remaining : {}", bullets.0 + 1), - ) - .await?; } - debugln!("Bullets Map : {:?}", bullets_map); Ok(()) } diff --git a/src/main.rs b/src/main.rs index 9d9a527..e6209b5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,11 @@ use crate::commands::{ general::GENERAL_GROUP, - roulette::{BulletsContainer, ROULETTE_GROUP}, + roulette::{BulletsContainer, NonKickGuildsContainer, ROULETTE_GROUP}, }; use async_trait::async_trait; use serde_json::Value; use serenity::{ + client::bridge::gateway::ShardManager, framework::standard::{ help_commands, macros::{help, hook}, @@ -20,6 +21,7 @@ use std::{ fs::{self}, io::Result as IoResult, path::Path, + sync::Arc, time::Duration, }; use tokio::{fs::File, io::AsyncWriteExt}; @@ -37,6 +39,12 @@ const PREFIX: &str = "?"; static mut LOG_ATTACHMENTS: bool = false; pub(crate) static mut INVITE_URL: Option = None; +struct ShardManagerContainer; + +impl TypeMapKey for ShardManagerContainer { + type Value = Arc>; +} + //TODO CLAP FOR CLI #[tokio::main] async fn main() -> IoResult<()> { @@ -64,6 +72,11 @@ async fn main() -> IoResult<()> { fs::create_dir(dir)?; } + let data_dir = Path::new("data"); + if !data_dir.exists() { + fs::create_dir(data_dir)?; + } + let http = Http::new_with_token(&token); // We will fetch your bot's owners and id @@ -92,9 +105,6 @@ async fn main() -> IoResult<()> { // are owners only. .owners(owners) }) - // Set a function that's called whenever an attempted command-call's - // command could not be found. - .unrecognised_command(unknown_command) // Set a function that's called whenever a command's execution didn't complete for one // reason or another. For example, when a user has exceeded a rate-limit or a command // can only be performed by the bot owner. @@ -111,7 +121,11 @@ async fn main() -> IoResult<()> { if cfg!(debug_assertions) { // Set a function that's called whenever a message is not a command. - framework = framework.normal_message(normal_message) + framework = framework + .normal_message(normal_message) + // Set a function that's called whenever an attempted command-call's + // command could not be found. + .unrecognised_command(unknown_command) } let mut client = Client::new(&token) @@ -123,10 +137,18 @@ async fn main() -> IoResult<()> { { let mut data = client.data.write().await; data.insert::(HashMap::default()); + data.insert::(Arc::clone(&client.shard_manager)); #[cfg(feature = "music")] { data.insert::(std::sync::Arc::clone(&client.voice_manager)); } + + let non_kick_guilds = data_dir.join("nonkickguilds.json"); + data.insert::(if non_kick_guilds.exists() { + serde_json::from_reader(fs::File::open(non_kick_guilds)?)? + } else { + HashSet::new() + }) } client.start().await.unwrap(); @@ -201,6 +223,23 @@ impl EventHandler for Messages { tokio::time::delay_for(delay).await; } }); + + tokio::signal::ctrl_c().await.unwrap(); + debugln!("ctrl-c"); + let data = ctx.data.read().await; + + println!("Saving data ..."); + if let Err(e) = save_data(&data).await { + eprintln!("Error while saving data : {:?}", e); + } + println!("Data saved"); + + if let Some(manager) = data.get::() { + manager.lock().await.shutdown_all().await; + println!("Stopped"); + } else { + eprintln!("There was a problem getting the shard manager"); + } } async fn message(&self, _ctx: Context, new_message: Message) { @@ -218,6 +257,21 @@ impl EventHandler for Messages { } } +async fn save_data(data: &tokio::sync::RwLockReadGuard<'_, TypeMap>) -> commands::Result<()> { + let data_path = Path::new("data"); + + if let Some(data) = data.get::() { + let mut f = File::create(data_path.join("nonkickguilds.json")).await?; + let json = if cfg!(debug_assertions) { + serde_json::to_string_pretty(data)? + } else { + serde_json::to_string(data)? + }; + f.write_all(&json.as_bytes()).await?; + } + Ok(()) +} + async fn download_to_log(attachment: Attachment) -> commands::Result<()> { debugln!("Download_to_log : {:?}", attachment); let path = Path::new("logging").join(format!("{}-{}", attachment.id, attachment.filename));