Diffstat (limited to 'net/vmw_vsock/virtio_transport.c')
-rw-r--r--  net/vmw_vsock/virtio_transport.c  265
1 file changed, 168 insertions(+), 97 deletions(-)
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index 96ab344f17bb..cc70d651d13e 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -39,6 +39,7 @@ struct virtio_vsock {
* must be accessed with tx_lock held.
*/
struct mutex tx_lock;
+ bool tx_run;
struct work_struct send_pkt_work;
spinlock_t send_pkt_list_lock;
@@ -54,6 +55,7 @@ struct virtio_vsock {
* must be accessed with rx_lock held.
*/
struct mutex rx_lock;
+ bool rx_run;
int rx_buf_nr;
int rx_buf_max_nr;
@@ -61,46 +63,28 @@ struct virtio_vsock {
* vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held.
*/
struct mutex event_lock;
+ bool event_run;
struct virtio_vsock_event event_list[8];
u32 guest_cid;
};
-static struct virtio_vsock *virtio_vsock_get(void)
-{
- return the_virtio_vsock;
-}
-
static u32 virtio_transport_get_local_cid(void)
{
- struct virtio_vsock *vsock = virtio_vsock_get();
-
- if (!vsock)
- return VMADDR_CID_ANY;
-
- return vsock->guest_cid;
-}
-
-static void virtio_transport_loopback_work(struct work_struct *work)
-{
- struct virtio_vsock *vsock =
- container_of(work, struct virtio_vsock, loopback_work);
- LIST_HEAD(pkts);
-
- spin_lock_bh(&vsock->loopback_list_lock);
- list_splice_init(&vsock->loopback_list, &pkts);
- spin_unlock_bh(&vsock->loopback_list_lock);
-
- mutex_lock(&vsock->rx_lock);
- while (!list_empty(&pkts)) {
- struct virtio_vsock_pkt *pkt;
-
- pkt = list_first_entry(&pkts, struct virtio_vsock_pkt, list);
- list_del_init(&pkt->list);
+ struct virtio_vsock *vsock;
+ u32 ret;
- virtio_transport_recv_pkt(pkt);
+ rcu_read_lock();
+ vsock = rcu_dereference(the_virtio_vsock);
+ if (!vsock) {
+ ret = VMADDR_CID_ANY;
+ goto out_rcu;
}
- mutex_unlock(&vsock->rx_lock);
+
+ ret = vsock->guest_cid;
+out_rcu:
+ rcu_read_unlock();
+ return ret;
}
static int virtio_transport_send_pkt_loopback(struct virtio_vsock *vsock,
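
The hunk above replaces direct use of the bare global the_virtio_vsock with an RCU-protected pointer: readers bracket their access with rcu_read_lock()/rcu_read_unlock() and fetch the pointer through rcu_dereference(), so the .remove() path can unpublish the device and wait out concurrent readers instead of racing with them. Below is a minimal userspace sketch of the read side, assuming liburcu (compile and link with -lurcu); the_dev, my_dev and get_cid are hypothetical stand-ins for the kernel symbols, not the driver's API:

/* Userspace analogue of the RCU read-side pattern above (liburcu). */
#include <urcu.h>	/* rcu_read_lock, rcu_dereference, ... */
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

struct my_dev {
	uint32_t guest_cid;
};

static struct my_dev *the_dev;	/* RCU-protected, like the_virtio_vsock */

#define CID_ANY ((uint32_t)-1)

static uint32_t get_cid(void)
{
	struct my_dev *dev;
	uint32_t ret;

	rcu_read_lock();		/* readers never block the updater */
	dev = rcu_dereference(the_dev);
	ret = dev ? dev->guest_cid : CID_ANY;
	rcu_read_unlock();
	return ret;
}

int main(void)
{
	struct my_dev *dev = malloc(sizeof(*dev));

	rcu_register_thread();		/* liburcu readers must register */
	dev->guest_cid = 3;
	rcu_assign_pointer(the_dev, dev);	/* publish (probe) */
	printf("cid=%u\n", get_cid());

	rcu_assign_pointer(the_dev, NULL);	/* unpublish (remove) */
	synchronize_rcu();			/* wait for in-flight readers */
	free(dev);				/* now safe to free */
	rcu_unregister_thread();
	return 0;
}
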
@@ -128,6 +112,9 @@ virtio_transport_send_pkt_work(struct work_struct *work)
mutex_lock(&vsock->tx_lock);
+ if (!vsock->tx_run)
+ goto out;
+
vq = vsock->vqs[VSOCK_VQ_TX];
for (;;) {
@@ -186,6 +173,7 @@ virtio_transport_send_pkt_work(struct work_struct *work)
if (added)
virtqueue_kick(vq);
+out:
mutex_unlock(&vsock->tx_lock);
if (restart_rx)
@@ -198,14 +186,18 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
struct virtio_vsock *vsock;
int len = pkt->len;
- vsock = virtio_vsock_get();
+ rcu_read_lock();
+ vsock = rcu_dereference(the_virtio_vsock);
if (!vsock) {
virtio_transport_free_pkt(pkt);
- return -ENODEV;
+ len = -ENODEV;
+ goto out_rcu;
}
- if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid)
- return virtio_transport_send_pkt_loopback(vsock, pkt);
+ if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) {
+ len = virtio_transport_send_pkt_loopback(vsock, pkt);
+ goto out_rcu;
+ }
if (pkt->reply)
atomic_inc(&vsock->queued_replies);
@@ -215,6 +207,9 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
spin_unlock_bh(&vsock->send_pkt_list_lock);
queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
+
+out_rcu:
+ rcu_read_unlock();
return len;
}
@@ -223,12 +218,14 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk)
{
struct virtio_vsock *vsock;
struct virtio_vsock_pkt *pkt, *n;
- int cnt = 0;
+ int cnt = 0, ret;
LIST_HEAD(freeme);
- vsock = virtio_vsock_get();
+ rcu_read_lock();
+ vsock = rcu_dereference(the_virtio_vsock);
if (!vsock) {
- return -ENODEV;
+ ret = -ENODEV;
+ goto out_rcu;
}
spin_lock_bh(&vsock->send_pkt_list_lock);
@@ -256,7 +253,11 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk)
queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}
- return 0;
+ ret = 0;
+
+out_rcu:
+ rcu_read_unlock();
+ return ret;
}
static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
@@ -308,6 +309,10 @@ static void virtio_transport_tx_work(struct work_struct *work)
vq = vsock->vqs[VSOCK_VQ_TX];
mutex_lock(&vsock->tx_lock);
+
+ if (!vsock->tx_run)
+ goto out;
+
do {
struct virtio_vsock_pkt *pkt;
unsigned int len;
@@ -318,6 +323,8 @@ static void virtio_transport_tx_work(struct work_struct *work)
added = true;
}
} while (!virtqueue_enable_cb(vq));
+
+out:
mutex_unlock(&vsock->tx_lock);
if (added)
@@ -336,56 +343,6 @@ static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
return val < virtqueue_get_vring_size(vq);
}
-static void virtio_transport_rx_work(struct work_struct *work)
-{
- struct virtio_vsock *vsock =
- container_of(work, struct virtio_vsock, rx_work);
- struct virtqueue *vq;
-
- vq = vsock->vqs[VSOCK_VQ_RX];
-
- mutex_lock(&vsock->rx_lock);
-
- do {
- virtqueue_disable_cb(vq);
- for (;;) {
- struct virtio_vsock_pkt *pkt;
- unsigned int len;
-
- if (!virtio_transport_more_replies(vsock)) {
- /* Stop rx until the device processes already
- * pending replies. Leave rx virtqueue
- * callbacks disabled.
- */
- goto out;
- }
-
- pkt = virtqueue_get_buf(vq, &len);
- if (!pkt) {
- break;
- }
-
- vsock->rx_buf_nr--;
-
- /* Drop short/long packets */
- if (unlikely(len < sizeof(pkt->hdr) ||
- len > sizeof(pkt->hdr) + pkt->len)) {
- virtio_transport_free_pkt(pkt);
- continue;
- }
-
- pkt->len = len - sizeof(pkt->hdr);
- virtio_transport_deliver_tap_pkt(pkt);
- virtio_transport_recv_pkt(pkt);
- }
- } while (!virtqueue_enable_cb(vq));
-
-out:
- if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
- virtio_vsock_rx_fill(vsock);
- mutex_unlock(&vsock->rx_lock);
-}
-
/* event_lock must be held */
static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock,
struct virtio_vsock_event *event)
@@ -455,6 +412,9 @@ static void virtio_transport_event_work(struct work_struct *work)
mutex_lock(&vsock->event_lock);
+ if (!vsock->event_run)
+ goto out;
+
do {
struct virtio_vsock_event *event;
unsigned int len;
@@ -469,7 +429,7 @@ static void virtio_transport_event_work(struct work_struct *work)
} while (!virtqueue_enable_cb(vq));
virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
-
+out:
mutex_unlock(&vsock->event_lock);
}
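
The tx_run/rx_run/event_run flags added above all follow one rule: each flag is read and written only under the mutex its worker already holds, so once .remove() flips it to false under that lock, any work item that runs afterwards bails out before touching the virtqueue. A small pthreads sketch of the same stop-flag discipline; worker_fn and dev_stop are hypothetical names, not part of the driver:

/* Stop flag read and written only under the worker's own mutex. */
#include <pthread.h>
#include <stdbool.h>
#include <stdio.h>

struct worker {
	pthread_mutex_t lock;
	bool run;		/* accessed only with lock held */
};

static void worker_fn(struct worker *w)
{
	pthread_mutex_lock(&w->lock);
	if (!w->run)
		goto out;	/* device is going away, do nothing */
	/* ... process virtqueue buffers here ... */
	puts("worker ran");
out:
	pthread_mutex_unlock(&w->lock);
}

static void dev_stop(struct worker *w)
{
	pthread_mutex_lock(&w->lock);
	w->run = false;		/* any worker invoked after this bails out */
	pthread_mutex_unlock(&w->lock);
	/* safe to reset the device now: no worker body holds the lock */
}

int main(void)
{
	struct worker w = { PTHREAD_MUTEX_INITIALIZER, true };

	worker_fn(&w);		/* runs normally */
	dev_stop(&w);
	worker_fn(&w);		/* bails out silently */
	return 0;
}
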
@@ -546,6 +506,86 @@ static struct virtio_transport virtio_transport = {
.send_pkt = virtio_transport_send_pkt,
};
+static void virtio_transport_loopback_work(struct work_struct *work)
+{
+ struct virtio_vsock *vsock =
+ container_of(work, struct virtio_vsock, loopback_work);
+ LIST_HEAD(pkts);
+
+ spin_lock_bh(&vsock->loopback_list_lock);
+ list_splice_init(&vsock->loopback_list, &pkts);
+ spin_unlock_bh(&vsock->loopback_list_lock);
+
+ mutex_lock(&vsock->rx_lock);
+
+ if (!vsock->rx_run)
+ goto out;
+
+ while (!list_empty(&pkts)) {
+ struct virtio_vsock_pkt *pkt;
+
+ pkt = list_first_entry(&pkts, struct virtio_vsock_pkt, list);
+ list_del_init(&pkt->list);
+
+ virtio_transport_recv_pkt(&virtio_transport, pkt);
+ }
+out:
+ mutex_unlock(&vsock->rx_lock);
+}
+
+static void virtio_transport_rx_work(struct work_struct *work)
+{
+ struct virtio_vsock *vsock =
+ container_of(work, struct virtio_vsock, rx_work);
+ struct virtqueue *vq;
+
+ vq = vsock->vqs[VSOCK_VQ_RX];
+
+ mutex_lock(&vsock->rx_lock);
+
+ if (!vsock->rx_run)
+ goto out;
+
+ do {
+ virtqueue_disable_cb(vq);
+ for (;;) {
+ struct virtio_vsock_pkt *pkt;
+ unsigned int len;
+
+ if (!virtio_transport_more_replies(vsock)) {
+ /* Stop rx until the device processes already
+ * pending replies. Leave rx virtqueue
+ * callbacks disabled.
+ */
+ goto out;
+ }
+
+ pkt = virtqueue_get_buf(vq, &len);
+ if (!pkt) {
+ break;
+ }
+
+ vsock->rx_buf_nr--;
+
+ /* Drop short/long packets */
+ if (unlikely(len < sizeof(pkt->hdr) ||
+ len > sizeof(pkt->hdr) + pkt->len)) {
+ virtio_transport_free_pkt(pkt);
+ continue;
+ }
+
+ pkt->len = len - sizeof(pkt->hdr);
+ virtio_transport_deliver_tap_pkt(pkt);
+ virtio_transport_recv_pkt(&virtio_transport, pkt);
+ }
+ } while (!virtqueue_enable_cb(vq));
+
+out:
+ if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
+ virtio_vsock_rx_fill(vsock);
+ mutex_unlock(&vsock->rx_lock);
+}
+
static int virtio_vsock_probe(struct virtio_device *vdev)
{
vq_callback_t *callbacks[] = {
@@ -566,7 +606,8 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
return ret;
/* Only one virtio-vsock device per guest is supported */
- if (the_virtio_vsock) {
+ if (rcu_dereference_protected(the_virtio_vsock,
+ lockdep_is_held(&the_virtio_vsock_mutex))) {
ret = -EBUSY;
goto out;
}
@@ -591,8 +632,6 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
vsock->rx_buf_max_nr = 0;
atomic_set(&vsock->queued_replies, 0);
- vdev->priv = vsock;
- the_virtio_vsock = vsock;
mutex_init(&vsock->tx_lock);
mutex_init(&vsock->rx_lock);
mutex_init(&vsock->event_lock);
@@ -606,14 +645,23 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work);
INIT_WORK(&vsock->loopback_work, virtio_transport_loopback_work);
+ mutex_lock(&vsock->tx_lock);
+ vsock->tx_run = true;
+ mutex_unlock(&vsock->tx_lock);
+
mutex_lock(&vsock->rx_lock);
virtio_vsock_rx_fill(vsock);
+ vsock->rx_run = true;
mutex_unlock(&vsock->rx_lock);
mutex_lock(&vsock->event_lock);
virtio_vsock_event_fill(vsock);
+ vsock->event_run = true;
mutex_unlock(&vsock->event_lock);
+ vdev->priv = vsock;
+ rcu_assign_pointer(the_virtio_vsock, vsock);
+
mutex_unlock(&the_virtio_vsock_mutex);
return 0;
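
Note the reordering in probe: vdev->priv and the_virtio_vsock are now assigned only after the run flags are set and the rx/event queues are filled. rcu_assign_pointer() implies a release barrier, so any reader that obtains the pointer is guaranteed to observe the completed initialization. The same publish-last idea expressed with C11 atomics, as a standalone sketch (my_dev and probe are hypothetical names):

/* Publish-last: initialize everything, then release-store the pointer. */
#include <stdatomic.h>
#include <stdbool.h>
#include <stdlib.h>

struct my_dev {
	bool tx_run, rx_run, event_run;
};

static _Atomic(struct my_dev *) the_dev;

static int probe(void)
{
	struct my_dev *dev = calloc(1, sizeof(*dev));

	if (!dev)
		return -1;
	dev->tx_run = true;	/* initialize everything first... */
	dev->rx_run = true;
	dev->event_run = true;
	/* ...then publish with a release store: a reader that loads the
	 * pointer with acquire semantics is guaranteed to see the init,
	 * mirroring what rcu_assign_pointer() provides in the kernel.
	 */
	atomic_store_explicit(&the_dev, dev, memory_order_release);
	return 0;
}

int main(void) { return probe(); }
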
@@ -628,6 +676,12 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
struct virtio_vsock *vsock = vdev->priv;
struct virtio_vsock_pkt *pkt;
+ mutex_lock(&the_virtio_vsock_mutex);
+
+ vdev->priv = NULL;
+ rcu_assign_pointer(the_virtio_vsock, NULL);
+ synchronize_rcu();
+
flush_work(&vsock->loopback_work);
flush_work(&vsock->rx_work);
flush_work(&vsock->tx_work);
@@ -637,6 +691,24 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
/* Reset all connected sockets when the device disappear */
vsock_for_each_connected_socket(virtio_vsock_reset_sock);
+ /* Stop all work handlers to make sure no one is accessing the device,
+ * so we can safely call vdev->config->reset().
+ */
+ mutex_lock(&vsock->rx_lock);
+ vsock->rx_run = false;
+ mutex_unlock(&vsock->rx_lock);
+
+ mutex_lock(&vsock->tx_lock);
+ vsock->tx_run = false;
+ mutex_unlock(&vsock->tx_lock);
+
+ mutex_lock(&vsock->event_lock);
+ vsock->event_run = false;
+ mutex_unlock(&vsock->event_lock);
+
+ /* Flush all device writes and interrupts, device will not use any
+ * more buffers.
+ */
vdev->config->reset(vdev);
mutex_lock(&vsock->rx_lock);
@@ -667,12 +739,11 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
}
spin_unlock_bh(&vsock->loopback_list_lock);
- mutex_lock(&the_virtio_vsock_mutex);
- the_virtio_vsock = NULL;
- mutex_unlock(&the_virtio_vsock_mutex);
-
+ /* Delete virtqueues and flush outstanding callbacks if any */
vdev->config->del_vqs(vdev);
+ mutex_unlock(&the_virtio_vsock_mutex);
+
kfree(vsock);
}
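
The remove path is the mirror image: unpublish the pointer first, let a grace period elapse with synchronize_rcu(), and only then flush the workers, clear the run flags, reset the device and delete the virtqueues, all under the_virtio_vsock_mutex so a concurrent probe cannot re-publish a half-torn-down device. A compressed liburcu sketch of that ordering (remove_dev and the_dev are hypothetical, as in the earlier sketches):

/* Teardown half of the pattern: unpublish, grace period, then free. */
#include <urcu.h>
#include <stdlib.h>

struct my_dev { int irq_count; };
static struct my_dev *the_dev;

static void remove_dev(void)
{
	struct my_dev *dev = the_dev;	/* updater side, under module lock */

	/* 1. Unpublish: rcu_dereference() in readers now returns NULL,
	 *    so new calls fail (-ENODEV) instead of touching dev.
	 */
	rcu_assign_pointer(the_dev, NULL);
	/* 2. Grace period: wait for every reader already inside an
	 *    rcu_read_lock() section to leave it.
	 */
	synchronize_rcu();
	/* 3. Only now: flush workers, clear the run flags, reset the
	 *    device, drain the queues, and finally free the state.
	 */
	free(dev);
}

int main(void)
{
	rcu_register_thread();
	the_dev = calloc(1, sizeof(*the_dev));
	remove_dev();
	rcu_unregister_thread();
	return 0;
}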