optic_graph.rs 12.31 KiB
use std::{cell::RefCell, rc::Rc};
use petgraph::{algo::is_cyclic_directed, prelude::DiGraph, stable_graph::NodeIndex};
use serde::{
de::{self, MapAccess, Visitor},
ser::SerializeStruct,
Deserialize, Serialize,
};
use uuid::Uuid;
use crate::{
error::{OpmResult, OpossumError},
light::Light,
optical::Optical,
};
use crate::{optic_ref::OpticRef, properties::Proptype};
/// Directed graph of optical nodes.
///
/// Newtype around a petgraph [`DiGraph`] whose node weights are [`OpticRef`]s
/// and whose edge weights are [`Light`] values. Each edge created through
/// [`OpticGraph::connect_nodes`] carries the name of the source port and of
/// the target port it links.
///
/// The inner graph is public, so callers may also access it directly via `.0`.
#[derive(Clone, Debug, Default)]
pub struct OpticGraph(pub DiGraph<OpticRef, Light>);
impl OpticGraph {
    /// Adds `node` to the graph and returns the index of the new graph node.
    ///
    /// The node is wrapped in `Rc<RefCell<_>>` and stored as an [`OpticRef`]
    /// (constructed with `None` as its second argument).
    pub fn add_node<T: Optical + 'static>(&mut self, node: T) -> NodeIndex {
        self.0
            .add_node(OpticRef::new(Rc::new(RefCell::new(node)), None))
    }

    /// Connects the output port `src_port` of `src_node` with the input port
    /// `target_port` of `target_node`.
    ///
    /// On success a new edge carrying `Light::new(src_port, target_port)` is
    /// added to the graph.
    ///
    /// # Errors
    ///
    /// Returns [`OpossumError::OpticScenery`] (checked in this order) if
    /// - `src_node` does not exist in the graph,
    /// - the source node has no output port named `src_port`,
    /// - `target_node` does not exist in the graph,
    /// - the target node has no input port named `target_port`,
    /// - the source port already has an outgoing edge,
    /// - the target port already has an incoming edge,
    /// - the new edge would make the graph cyclic. In this case the edge is
    ///   removed again, so the graph is left unchanged.
    pub fn connect_nodes(
        &mut self,
        src_node: NodeIndex,
        src_port: &str,
        target_node: NodeIndex,
        target_port: &str,
    ) -> OpmResult<()> {
        // `ok_or_else` builds the error (and its `String`) only if the lookup fails.
        let source = self.0.node_weight(src_node).ok_or_else(|| {
            OpossumError::OpticScenery("source node with given index does not exist".into())
        })?;
        if !source
            .optical_ref
            .borrow()
            .ports()
            .outputs()
            .contains(&src_port.into())
        {
            return Err(OpossumError::OpticScenery(format!(
                "source node {} does not have a port {}",
                source.optical_ref.borrow().name(),
                src_port
            )));
        }
        let target = self.0.node_weight(target_node).ok_or_else(|| {
            OpossumError::OpticScenery("target node with given index does not exist".into())
        })?;
        if !target
            .optical_ref
            .borrow()
            .ports()
            .inputs()
            .contains(&target_port.into())
        {
            return Err(OpossumError::OpticScenery(format!(
                "target node {} does not have a port {}",
                target.optical_ref.borrow().name(),
                target_port
            )));
        }
        // Each port may take part in at most one connection.
        if self.src_node_port_exists(src_node, src_port) {
            return Err(OpossumError::OpticScenery(format!(
                "src node <{}> with port <{}> is already connected",
                source.optical_ref.borrow().name(),
                src_port
            )));
        }
        if self.target_node_port_exists(target_node, target_port) {
            return Err(OpossumError::OpticScenery(format!(
                "target node <{}> with port <{}> is already connected",
                target.optical_ref.borrow().name(),
                target_port
            )));
        }
        // Copy the names now: `source` / `target` borrow the graph immutably
        // and cannot be used once the graph is mutated below.
        let src_name = source.optical_ref.borrow().name().to_owned();
        let target_name = target.optical_ref.borrow().name().to_owned();
        let edge_index = self
            .0
            .add_edge(src_node, target_node, Light::new(src_port, target_port));
        // Add the edge tentatively and roll it back if it closes a loop.
        if is_cyclic_directed(&self.0) {
            self.0.remove_edge(edge_index);
            return Err(OpossumError::OpticScenery(format!(
                "connecting nodes <{}> -> <{}> would form a loop",
                src_name, target_name
            )));
        }
        Ok(())
    }

    /// Returns `true` if `src_node` already has an outgoing edge that starts
    /// at the port named `src_port`.
    fn src_node_port_exists(&self, src_node: NodeIndex, src_port: &str) -> bool {
        self.0
            .edges_directed(src_node, petgraph::Direction::Outgoing)
            .any(|e| e.weight().src_port() == src_port)
    }

    /// Returns `true` if `target_node` already has an incoming edge that ends
    /// at the port named `target_port`.
    fn target_node_port_exists(&self, target_node: NodeIndex, target_port: &str) -> bool {
        self.0
            .edges_directed(target_node, petgraph::Direction::Incoming)
            .any(|e| e.weight().target_port() == target_port)
    }

    /// Returns a clone of the first node whose uuid equals `uuid`, or `None`
    /// if the graph contains no such node.
    pub fn node(&self, uuid: Uuid) -> Option<OpticRef> {
        self.0
            .node_weights()
            .find(|node| node.uuid() == uuid)
            .cloned()
    }

    /// Returns the graph index of the first node whose uuid equals `uuid`, or
    /// `None` if the graph contains no such node.
    pub fn node_idx(&self, uuid: Uuid) -> Option<NodeIndex> {
        self.0.node_indices().find(|&idx| {
            self.0
                .node_weight(idx)
                .map_or(false, |node| node.uuid() == uuid)
        })
    }
}
/// Serializes the graph as a struct with two fields:
///
/// - `nodes`: the sequence of all node weights ([`OpticRef`]) in node-index order,
/// - `edges`: a sequence of `(source uuid, target uuid, source port, target port)`
///   tuples, one per edge, in edge-index order.
///
/// Edges refer to nodes by uuid rather than by graph index.
impl Serialize for OpticGraph {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Serialization only reads the graph, so borrow it instead of cloning it.
        let g = &self.0;
        // NOTE(review): the struct name used here is "graph" while the
        // `Deserialize` impl passes "OpticGraph"; kept unchanged so that the
        // serialized output stays identical.
        let mut graph = serializer.serialize_struct("graph", 2)?;
        // A `Vec` of references serializes exactly like a `Vec` of owned nodes.
        let nodes = g.node_weights().collect::<Vec<&OpticRef>>();
        graph.serialize_field("nodes", &nodes)?;
        // `raw_edges` exposes endpoints and weight of every edge directly, so
        // no fallible per-edge lookups are necessary.
        let edges = g
            .raw_edges()
            .iter()
            .map(|edge| {
                (
                    g[edge.source()].uuid(),
                    g[edge.target()].uuid(),
                    edge.weight.src_port(),
                    edge.weight.target_port(),
                )
            })
            .collect::<Vec<(Uuid, Uuid, &str, &str)>>();
        graph.serialize_field("edges", &edges)?;
        graph.end()
    }
}
impl<'de> Deserialize<'de> for OpticGraph {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
enum Field {
Nodes,
Edges,
}
const FIELDS: &[&str] = &["nodes", "edges"];
impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct FieldVisitor;
impl<'de> Visitor<'de> for FieldVisitor {
type Value = Field;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("`nodes` or `edges`")
}
fn visit_str<E>(self, value: &str) -> std::result::Result<Field, E>
where
E: de::Error,
{
match value {
"nodes" => Ok(Field::Nodes),
"edges" => Ok(Field::Edges),
_ => Err(de::Error::unknown_field(value, FIELDS)),
}
}
}
deserializer.deserialize_identifier(FieldVisitor)
}
}
struct OpticGraphVisitor;
impl<'de> Visitor<'de> for OpticGraphVisitor {
type Value = OpticGraph;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("an OpticGraph")
}
fn visit_map<A>(self, mut map: A) -> std::result::Result<OpticGraph, A::Error>
where
A: MapAccess<'de>,
{
let mut g = OpticGraph::default();
let mut nodes: Option<Vec<OpticRef>> = None;
let mut edges: Option<Vec<(Uuid, Uuid, &str, &str)>> = None;
while let Some(key) = map.next_key()? {
match key {
Field::Nodes => {
if nodes.is_some() {
return Err(de::Error::duplicate_field("nodes"));
}
nodes = Some(map.next_value::<Vec<OpticRef>>()?);
}
Field::Edges => {
if edges.is_some() {
return Err(de::Error::duplicate_field("edges"));
}
edges = Some(map.next_value::<Vec<(Uuid, Uuid, &str, &str)>>()?);
}
}
}
let nodes = nodes.ok_or_else(|| de::Error::missing_field("nodes"))?;
let edges = edges.ok_or_else(|| de::Error::missing_field("edges"))?;
for node in nodes.iter() {
g.0.add_node(node.clone());
}
// assign references to ref nodes (if any)
for node in nodes.iter() {
if node.optical_ref.borrow().node_type() == "reference" {
let mut my_node = node.optical_ref.borrow_mut();
let refnode = my_node.as_refnode_mut().unwrap();
let node_props = refnode.properties().clone();
let uuid =
if let Proptype::Uuid(uuid) = node_props.get("reference id").unwrap() {
*uuid
} else {
Uuid::nil()
};
let ref_node = g.node(uuid).unwrap();
let ref_name = format!("ref ({})", ref_node.optical_ref.borrow().name());
refnode.assign_reference(ref_node);
refnode
.set_property("name", Proptype::String(ref_name))
.unwrap();
}
}
for edge in edges.iter() {
let src_idx = g.node_idx(edge.0).ok_or_else(|| {
de::Error::custom(format!("src id {} does not exist", edge.0))
})?;
let target_idx = g.node_idx(edge.1).ok_or_else(|| {
de::Error::custom(format!("target id {} does not exist", edge.1))
})?;
g.connect_nodes(src_idx, edge.2, target_idx, edge.3)
.map_err(|e| {
de::Error::custom(format!("connecting OpticGraph nodes failed: {}", e))
})?;
}
Ok(g)
}
}
deserializer.deserialize_struct("OpticGraph", FIELDS, OpticGraphVisitor)
}
}
impl From<OpticGraph> for Proptype {
fn from(value: OpticGraph) -> Self {
Proptype::OpticGraph(value)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::nodes::Dummy;

    /// Builds a graph holding `count` unconnected `Dummy` nodes (all named
    /// "Test") and returns it together with the node indices in insertion order.
    fn graph_with_dummies(count: usize) -> (OpticGraph, Vec<NodeIndex>) {
        let mut graph = OpticGraph::default();
        let indices: Vec<NodeIndex> = (0..count)
            .map(|_| graph.add_node(Dummy::new("Test")))
            .collect();
        (graph, indices)
    }

    /// Adding a node increases the node count.
    #[test]
    fn add_node() {
        let mut graph = OpticGraph::default();
        graph.add_node(Dummy::new("n1"));
        let node_count = graph.0.node_count();
        assert_eq!(node_count, 1);
    }

    /// Connecting an output port to a free input port creates one edge.
    #[test]
    fn connect_nodes_ok() {
        let (mut graph, idx) = graph_with_dummies(2);
        let result = graph.connect_nodes(idx[0], "rear", idx[1], "front");
        assert!(result.is_ok());
        assert_eq!(graph.0.edge_count(), 1);
    }

    /// Connecting from or to a non-existing node index is rejected.
    #[test]
    fn connect_nodes_failure() {
        let (mut graph, idx) = graph_with_dummies(2);
        let missing = NodeIndex::new(5);
        let to_missing = graph.connect_nodes(idx[0], "rear", missing, "front");
        assert!(to_missing.is_err());
        let from_missing = graph.connect_nodes(missing, "rear", idx[1], "front");
        assert!(from_missing.is_err());
    }

    /// An input port accepts only a single incoming connection.
    #[test]
    fn connect_nodes_target_already_connected() {
        let (mut graph, idx) = graph_with_dummies(3);
        let first = graph.connect_nodes(idx[0], "rear", idx[1], "front");
        assert!(first.is_ok());
        let second = graph.connect_nodes(idx[2], "rear", idx[1], "front");
        assert!(second.is_err());
    }

    /// A connection closing a loop is rejected and leaves the graph untouched.
    #[test]
    fn connect_nodes_loop_error() {
        let (mut graph, idx) = graph_with_dummies(2);
        let forward = graph.connect_nodes(idx[0], "rear", idx[1], "front");
        assert!(forward.is_ok());
        let backward = graph.connect_nodes(idx[1], "rear", idx[0], "front");
        assert!(backward.is_err());
        assert_eq!(graph.0.edge_count(), 1);
    }

    /// `node` finds nodes by uuid and returns `None` for unknown uuids.
    #[test]
    fn node() {
        let mut graph = OpticGraph::default();
        let idx = graph.add_node(Dummy::default());
        let known = graph.0.node_weight(idx).unwrap().uuid();
        let unknown = Uuid::new_v4();
        assert!(graph.node(known).is_some());
        assert!(graph.node(unknown).is_none());
    }

    /// `node_idx` maps a uuid back to its graph index.
    #[test]
    fn node_id() {
        let mut graph = OpticGraph::default();
        let idx = graph.add_node(Dummy::default());
        let known = graph.0.node_weight(idx).unwrap().uuid();
        let unknown = Uuid::new_v4();
        assert_eq!(graph.node_idx(known), Some(idx));
        assert_eq!(graph.node_idx(unknown), None);
    }
}