Create ws plugin
Signed-off-by: Alex Kattathra Johnson <alex.kattathra.johnson@gmail.com>
This commit is contained in:
commit
ee23be4aac
8 changed files with 2843 additions and 0 deletions
107
src/main.rs
Normal file
107
src/main.rs
Normal file
|
@ -0,0 +1,107 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use nu_plugin::{EvaluatedCall, JsonSerializer, serve_plugin};
|
||||
use nu_plugin::{EngineInterface, Plugin, PluginCommand};
|
||||
use nu_protocol::{ByteStream, ByteStreamType, Category, LabeledError, PipelineData, Signature, SyntaxShape, Type, Value};
|
||||
|
||||
mod ws;
|
||||
use ws::client::{connect, http_parse_url, request_headers};
|
||||
|
||||
struct WebSocketPlugin;
|
||||
|
||||
impl Plugin for WebSocketPlugin {
|
||||
fn version(&self) -> String {
|
||||
env!("CARGO_PKG_VERSION").into()
|
||||
}
|
||||
|
||||
fn commands(&self) -> Vec<Box<dyn PluginCommand<Plugin = Self>>> {
|
||||
vec![
|
||||
Box::new(WebSocket),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
struct WebSocket;
|
||||
|
||||
impl PluginCommand for WebSocket {
|
||||
type Plugin = WebSocketPlugin;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"ws"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"streams output from a websocket"
|
||||
}
|
||||
|
||||
fn signature(&self) -> Signature {
|
||||
Signature::build(PluginCommand::name(self))
|
||||
.input_output_type(Type::String, Type::Int)
|
||||
.required(
|
||||
"URL",
|
||||
SyntaxShape::String,
|
||||
"The URL to stream from (ws:// or wss://).",
|
||||
)
|
||||
.named(
|
||||
"headers",
|
||||
SyntaxShape::Any,
|
||||
"custom headers you want to add ",
|
||||
Some('H'),
|
||||
)
|
||||
.named(
|
||||
"max-time",
|
||||
SyntaxShape::Duration,
|
||||
"max duration before timeout occurs",
|
||||
Some('m'),
|
||||
)
|
||||
.filter()
|
||||
.category(Category::Network)
|
||||
}
|
||||
|
||||
fn run(
|
||||
&self,
|
||||
_plugin: &Self::Plugin,
|
||||
engine: &EngineInterface,
|
||||
call: &EvaluatedCall,
|
||||
_input: PipelineData,
|
||||
) -> Result<PipelineData, LabeledError> {
|
||||
let url: Value = call.req(0)?;
|
||||
let headers: Option<Value> = call.get_flag("headers")?;
|
||||
let timeout: Option<Value> = call.get_flag("max-time")?;
|
||||
|
||||
let span = url.span();
|
||||
|
||||
let (_, requested_url) = http_parse_url(call, span, url)?;
|
||||
|
||||
if ["ws", "wss"].contains(&requested_url.scheme()) {
|
||||
let timeout = timeout.map(|ref val| {
|
||||
Duration::from_nanos(
|
||||
val.as_duration()
|
||||
.expect("Timeout should be set to duration") as u64,
|
||||
)
|
||||
});
|
||||
if let Some(cr) = connect(
|
||||
requested_url,
|
||||
timeout,
|
||||
request_headers(headers)?,
|
||||
) {
|
||||
let reader = Box::new(cr);
|
||||
return Ok(PipelineData::ByteStream(
|
||||
ByteStream::read(
|
||||
reader,
|
||||
span,
|
||||
engine.signals().clone(),
|
||||
ByteStreamType::Unknown,
|
||||
),
|
||||
None,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Err(LabeledError::new("Unsupported input for command"))
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
serve_plugin(&WebSocketPlugin, JsonSerializer)
|
||||
}
|
211
src/ws/client.rs
Normal file
211
src/ws/client.rs
Normal file
|
@ -0,0 +1,211 @@
|
|||
use nu_plugin::EvaluatedCall;
|
||||
use nu_protocol::{ShellError, Span, Value};
|
||||
use url::Url;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::collections::HashMap;
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
io::Read,
|
||||
sync::{
|
||||
mpsc::{self, Receiver, TryRecvError},
|
||||
Arc, Mutex,
|
||||
},
|
||||
thread,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tungstenite::ClientRequestBuilder;
|
||||
|
||||
pub struct ChannelReader {
|
||||
rx: Arc<Mutex<Receiver<Vec<u8>>>>,
|
||||
deadline: Option<Instant>,
|
||||
buf_deque: VecDeque<u8>,
|
||||
}
|
||||
|
||||
impl ChannelReader {
|
||||
pub fn new(rx: Receiver<Vec<u8>>, timeout: Option<Duration>) -> Self {
|
||||
let mut cr = Self {
|
||||
rx: Arc::new(Mutex::new(rx)),
|
||||
deadline: None,
|
||||
buf_deque: VecDeque::new(),
|
||||
};
|
||||
if let Some(timeout) = timeout {
|
||||
cr.deadline = Some(Instant::now() + timeout);
|
||||
}
|
||||
cr
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for ChannelReader {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
let rx = self.rx.lock().expect("Could not get lock on receiver");
|
||||
|
||||
let bytes = match self.deadline {
|
||||
Some(deadline) => rx
|
||||
.recv_timeout(deadline.duration_since(Instant::now()))
|
||||
.map_err(|_| TryRecvError::Disconnected),
|
||||
None => rx.recv().map_err(|_| TryRecvError::Disconnected),
|
||||
};
|
||||
|
||||
let bytes = match bytes {
|
||||
Ok(bytes) => bytes,
|
||||
Err(..) => return Ok(0),
|
||||
};
|
||||
|
||||
for b in bytes {
|
||||
self.buf_deque.push_back(b);
|
||||
}
|
||||
|
||||
let mut len = 0;
|
||||
for buf in buf {
|
||||
if let Some(b) = self.buf_deque.pop_front() {
|
||||
*buf = b;
|
||||
len += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(len)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn connect(
|
||||
url: Url,
|
||||
timeout: Option<Duration>,
|
||||
headers: HashMap<String, String>,
|
||||
) -> Option<ChannelReader> {
|
||||
let mut builder = ClientRequestBuilder::new(url.as_str().parse().ok()?);
|
||||
builder = builder.with_header(
|
||||
"Origin",
|
||||
format!(
|
||||
"{}://{}:{}",
|
||||
url.scheme(),
|
||||
url.host_str().unwrap_or_default(),
|
||||
url.port().unwrap_or_default()
|
||||
),
|
||||
);
|
||||
for (k, v) in headers {
|
||||
builder = builder.with_header(k, v);
|
||||
}
|
||||
match tungstenite::connect(builder) {
|
||||
Ok((mut websocket, _)) => {
|
||||
let (tx, rx) = mpsc::sync_channel(1024);
|
||||
let tx = Arc::new(tx);
|
||||
thread::Builder::new()
|
||||
.name("websocket response sender".to_string())
|
||||
.spawn(move || loop {
|
||||
let tx = tx.clone();
|
||||
match websocket.read() {
|
||||
Ok(msg) => match msg {
|
||||
tungstenite::Message::Text(msg) => {
|
||||
if tx.send(msg.as_bytes().to_vec()).is_err() {
|
||||
websocket.close(Some(tungstenite::protocol::CloseFrame{
|
||||
code: tungstenite::protocol::frame::coding::CloseCode::Normal,
|
||||
reason: std::borrow::Cow::Borrowed("byte stream closed"),
|
||||
})).expect("Could not close connection")
|
||||
}
|
||||
}
|
||||
tungstenite::Message::Binary(msg) => {
|
||||
if tx.send(msg).is_err() {
|
||||
websocket.close(Some(tungstenite::protocol::CloseFrame{
|
||||
code: tungstenite::protocol::frame::coding::CloseCode::Normal,
|
||||
reason: std::borrow::Cow::Borrowed("byte stream closed"),
|
||||
})).expect("Could not close connection")
|
||||
}
|
||||
}
|
||||
tungstenite::Message::Close(..) => {
|
||||
drop(tx);
|
||||
return;
|
||||
}
|
||||
_ => continue,
|
||||
},
|
||||
_ => {
|
||||
drop(tx);
|
||||
return;
|
||||
}
|
||||
}
|
||||
})
|
||||
.ok()?;
|
||||
Some(ChannelReader::new(rx, timeout))
|
||||
}
|
||||
Err(..) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn http_parse_url(
|
||||
call: &EvaluatedCall,
|
||||
span: Span,
|
||||
raw_url: Value,
|
||||
) -> Result<(String, Url), ShellError> {
|
||||
let requested_url = raw_url.coerce_into_string()?;
|
||||
let url = match Url::parse(&requested_url) {
|
||||
Ok(u) => u,
|
||||
Err(_e) => {
|
||||
return Err(ShellError::UnsupportedInput { msg: "Incomplete or incorrect URL. Expected a full URL, e.g., https://www.example.com"
|
||||
.to_string(), input: format!("value: '{requested_url:?}'"), msg_span: call.head, input_span: span });
|
||||
}
|
||||
};
|
||||
|
||||
Ok((requested_url, url))
|
||||
}
|
||||
|
||||
pub fn request_headers(headers: Option<Value>) -> Result<HashMap<String, String>, ShellError> {
|
||||
let mut custom_headers: HashMap<String, Value> = HashMap::new();
|
||||
|
||||
if let Some(headers) = headers {
|
||||
match &headers {
|
||||
Value::Record { val, .. } => {
|
||||
for (k, v) in &**val {
|
||||
custom_headers.insert(k.to_string(), v.clone());
|
||||
}
|
||||
}
|
||||
|
||||
Value::List { vals: table, .. } => {
|
||||
if table.len() == 1 {
|
||||
// single row([key1 key2]; [val1 val2])
|
||||
match &table[0] {
|
||||
Value::Record { val, .. } => {
|
||||
for (k, v) in &**val {
|
||||
custom_headers.insert(k.to_string(), v.clone());
|
||||
}
|
||||
}
|
||||
|
||||
x => {
|
||||
return Err(ShellError::CantConvert {
|
||||
to_type: "string list or single row".into(),
|
||||
from_type: x.get_type().to_string(),
|
||||
span: headers.span(),
|
||||
help: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// primitive values ([key1 val1 key2 val2])
|
||||
for row in table.chunks(2) {
|
||||
if row.len() == 2 {
|
||||
custom_headers.insert(row[0].coerce_string()?, row[1].clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
x => {
|
||||
return Err(ShellError::CantConvert {
|
||||
to_type: "string list or single row".into(),
|
||||
from_type: x.get_type().to_string(),
|
||||
span: headers.span(),
|
||||
help: None,
|
||||
});
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let mut result = HashMap::new();
|
||||
for (k, v) in custom_headers {
|
||||
if let Ok(s) = v.coerce_into_string() {
|
||||
result.insert(k, s);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
1
src/ws/mod.rs
Normal file
1
src/ws/mod.rs
Normal file
|
@ -0,0 +1 @@
|
|||
pub mod client;
|
Loading…
Add table
Add a link
Reference in a new issue