diff options
Diffstat (limited to 'net/9p')
-rw-r--r-- | net/9p/client.c | 262 | ||||
-rw-r--r-- | net/9p/protocol.c | 28 | ||||
-rw-r--r-- | net/9p/trans_common.c | 42 | ||||
-rw-r--r-- | net/9p/trans_common.h | 2 | ||||
-rw-r--r-- | net/9p/trans_fd.c | 7 | ||||
-rw-r--r-- | net/9p/trans_rdma.c | 52 | ||||
-rw-r--r-- | net/9p/trans_virtio.c | 142 |
7 files changed, 264 insertions, 271 deletions
diff --git a/net/9p/client.c b/net/9p/client.c index e86a9bea1d16..6f4c4c88db84 100644 --- a/net/9p/client.c +++ b/net/9p/client.c @@ -34,6 +34,7 @@ #include <linux/slab.h> #include <linux/sched.h> #include <linux/uaccess.h> +#include <linux/uio.h> #include <net/9p/9p.h> #include <linux/parser.h> #include <net/9p/client.h> @@ -555,7 +556,7 @@ out_err: */ static int p9_check_zc_errors(struct p9_client *c, struct p9_req_t *req, - char *uidata, int in_hdrlen, int kern_buf) + struct iov_iter *uidata, int in_hdrlen) { int err; int ecode; @@ -591,16 +592,11 @@ static int p9_check_zc_errors(struct p9_client *c, struct p9_req_t *req, ename = &req->rc->sdata[req->rc->offset]; if (len > inline_len) { /* We have error in external buffer */ - if (kern_buf) { - memcpy(ename + inline_len, uidata, - len - inline_len); - } else { - err = copy_from_user(ename + inline_len, - uidata, len - inline_len); - if (err) { - err = -EFAULT; - goto out_err; - } + err = copy_from_iter(ename + inline_len, + len - inline_len, uidata); + if (err != len - inline_len) { + err = -EFAULT; + goto out_err; } } ename = NULL; @@ -806,8 +802,8 @@ reterr: * p9_client_zc_rpc - issue a request and wait for a response * @c: client session * @type: type of request - * @uidata: user bffer that should be ued for zero copy read - * @uodata: user buffer that shoud be user for zero copy write + * @uidata: destination for zero copy read + * @uodata: source for zero copy write * @inlen: read buffer size * @olen: write buffer size * @hdrlen: reader header size, This is the size of response protocol data @@ -816,9 +812,10 @@ reterr: * Returns request structure (which client must free using p9_free_req) */ static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type, - char *uidata, char *uodata, + struct iov_iter *uidata, + struct iov_iter *uodata, int inlen, int olen, int in_hdrlen, - int kern_buf, const char *fmt, ...) + const char *fmt, ...) { va_list ap; int sigpending, err; @@ -841,12 +838,8 @@ static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type, } else sigpending = 0; - /* If we are called with KERNEL_DS force kern_buf */ - if (segment_eq(get_fs(), KERNEL_DS)) - kern_buf = 1; - err = c->trans_mod->zc_request(c, req, uidata, uodata, - inlen, olen, in_hdrlen, kern_buf); + inlen, olen, in_hdrlen); if (err < 0) { if (err == -EIO) c->status = Disconnected; @@ -876,7 +869,7 @@ static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type, if (err < 0) goto reterr; - err = p9_check_zc_errors(c, req, uidata, in_hdrlen, kern_buf); + err = p9_check_zc_errors(c, req, uidata, in_hdrlen); trace_9p_client_res(c, type, req->rc->tag, err); if (!err) return req; @@ -1123,6 +1116,7 @@ struct p9_fid *p9_client_attach(struct p9_client *clnt, struct p9_fid *afid, fid = NULL; goto error; } + fid->uid = n_uname; req = p9_client_rpc(clnt, P9_TATTACH, "ddss?u", fid->fid, afid ? afid->fid : P9_NOFID, uname, aname, n_uname); @@ -1541,142 +1535,128 @@ error: EXPORT_SYMBOL(p9_client_unlinkat); int -p9_client_read(struct p9_fid *fid, char *data, char __user *udata, u64 offset, - u32 count) +p9_client_read(struct p9_fid *fid, u64 offset, struct iov_iter *to, int *err) { - char *dataptr; - int kernel_buf = 0; + struct p9_client *clnt = fid->clnt; struct p9_req_t *req; - struct p9_client *clnt; - int err, rsize, non_zc = 0; - + int total = 0; p9_debug(P9_DEBUG_9P, ">>> TREAD fid %d offset %llu %d\n", - fid->fid, (unsigned long long) offset, count); - err = 0; - clnt = fid->clnt; - - rsize = fid->iounit; - if (!rsize || rsize > clnt->msize-P9_IOHDRSZ) - rsize = clnt->msize - P9_IOHDRSZ; - - if (count < rsize) - rsize = count; - - /* Don't bother zerocopy for small IO (< 1024) */ - if (clnt->trans_mod->zc_request && rsize > 1024) { - char *indata; - if (data) { - kernel_buf = 1; - indata = data; - } else - indata = (__force char *)udata; - /* - * response header len is 11 - * PDU Header(7) + IO Size (4) - */ - req = p9_client_zc_rpc(clnt, P9_TREAD, indata, NULL, rsize, 0, - 11, kernel_buf, "dqd", fid->fid, - offset, rsize); - } else { - non_zc = 1; - req = p9_client_rpc(clnt, P9_TREAD, "dqd", fid->fid, offset, - rsize); - } - if (IS_ERR(req)) { - err = PTR_ERR(req); - goto error; - } + fid->fid, (unsigned long long) offset, (int)iov_iter_count(to)); + + while (iov_iter_count(to)) { + int count = iov_iter_count(to); + int rsize, non_zc = 0; + char *dataptr; + + rsize = fid->iounit; + if (!rsize || rsize > clnt->msize-P9_IOHDRSZ) + rsize = clnt->msize - P9_IOHDRSZ; + + if (count < rsize) + rsize = count; + + /* Don't bother zerocopy for small IO (< 1024) */ + if (clnt->trans_mod->zc_request && rsize > 1024) { + /* + * response header len is 11 + * PDU Header(7) + IO Size (4) + */ + req = p9_client_zc_rpc(clnt, P9_TREAD, to, NULL, rsize, + 0, 11, "dqd", fid->fid, + offset, rsize); + } else { + non_zc = 1; + req = p9_client_rpc(clnt, P9_TREAD, "dqd", fid->fid, offset, + rsize); + } + if (IS_ERR(req)) { + *err = PTR_ERR(req); + break; + } - err = p9pdu_readf(req->rc, clnt->proto_version, "D", &count, &dataptr); - if (err) { - trace_9p_protocol_dump(clnt, req->rc); - goto free_and_error; - } + *err = p9pdu_readf(req->rc, clnt->proto_version, + "D", &count, &dataptr); + if (*err) { + trace_9p_protocol_dump(clnt, req->rc); + p9_free_req(clnt, req); + break; + } - p9_debug(P9_DEBUG_9P, "<<< RREAD count %d\n", count); + p9_debug(P9_DEBUG_9P, "<<< RREAD count %d\n", count); + if (!count) { + p9_free_req(clnt, req); + break; + } - if (non_zc) { - if (data) { - memmove(data, dataptr, count); - } else { - err = copy_to_user(udata, dataptr, count); - if (err) { - err = -EFAULT; - goto free_and_error; + if (non_zc) { + int n = copy_to_iter(dataptr, count, to); + total += n; + offset += n; + if (n != count) { + *err = -EFAULT; + p9_free_req(clnt, req); + break; } + } else { + iov_iter_advance(to, count); + total += count; + offset += count; } + p9_free_req(clnt, req); } - p9_free_req(clnt, req); - return count; - -free_and_error: - p9_free_req(clnt, req); -error: - return err; + return total; } EXPORT_SYMBOL(p9_client_read); int -p9_client_write(struct p9_fid *fid, char *data, const char __user *udata, - u64 offset, u32 count) +p9_client_write(struct p9_fid *fid, u64 offset, struct iov_iter *from, int *err) { - int err, rsize; - int kernel_buf = 0; - struct p9_client *clnt; + struct p9_client *clnt = fid->clnt; struct p9_req_t *req; + int total = 0; + + p9_debug(P9_DEBUG_9P, ">>> TWRITE fid %d offset %llu count %zd\n", + fid->fid, (unsigned long long) offset, + iov_iter_count(from)); + + while (iov_iter_count(from)) { + int count = iov_iter_count(from); + int rsize = fid->iounit; + if (!rsize || rsize > clnt->msize-P9_IOHDRSZ) + rsize = clnt->msize - P9_IOHDRSZ; + + if (count < rsize) + rsize = count; + + /* Don't bother zerocopy for small IO (< 1024) */ + if (clnt->trans_mod->zc_request && rsize > 1024) { + req = p9_client_zc_rpc(clnt, P9_TWRITE, NULL, from, 0, + rsize, P9_ZC_HDR_SZ, "dqd", + fid->fid, offset, rsize); + } else { + req = p9_client_rpc(clnt, P9_TWRITE, "dqV", fid->fid, + offset, rsize, from); + } + if (IS_ERR(req)) { + *err = PTR_ERR(req); + break; + } - p9_debug(P9_DEBUG_9P, ">>> TWRITE fid %d offset %llu count %d\n", - fid->fid, (unsigned long long) offset, count); - err = 0; - clnt = fid->clnt; - - rsize = fid->iounit; - if (!rsize || rsize > clnt->msize-P9_IOHDRSZ) - rsize = clnt->msize - P9_IOHDRSZ; + *err = p9pdu_readf(req->rc, clnt->proto_version, "d", &count); + if (*err) { + trace_9p_protocol_dump(clnt, req->rc); + p9_free_req(clnt, req); + } - if (count < rsize) - rsize = count; + p9_debug(P9_DEBUG_9P, "<<< RWRITE count %d\n", count); - /* Don't bother zerocopy for small IO (< 1024) */ - if (clnt->trans_mod->zc_request && rsize > 1024) { - char *odata; - if (data) { - kernel_buf = 1; - odata = data; - } else - odata = (char *)udata; - req = p9_client_zc_rpc(clnt, P9_TWRITE, NULL, odata, 0, rsize, - P9_ZC_HDR_SZ, kernel_buf, "dqd", - fid->fid, offset, rsize); - } else { - if (data) - req = p9_client_rpc(clnt, P9_TWRITE, "dqD", fid->fid, - offset, rsize, data); - else - req = p9_client_rpc(clnt, P9_TWRITE, "dqU", fid->fid, - offset, rsize, udata); - } - if (IS_ERR(req)) { - err = PTR_ERR(req); - goto error; - } - - err = p9pdu_readf(req->rc, clnt->proto_version, "d", &count); - if (err) { - trace_9p_protocol_dump(clnt, req->rc); - goto free_and_error; + p9_free_req(clnt, req); + iov_iter_advance(from, count); + total += count; + offset += count; } - - p9_debug(P9_DEBUG_9P, "<<< RWRITE count %d\n", count); - - p9_free_req(clnt, req); - return count; - -free_and_error: - p9_free_req(clnt, req); -error: - return err; + return total; } EXPORT_SYMBOL(p9_client_write); @@ -2068,6 +2048,10 @@ int p9_client_readdir(struct p9_fid *fid, char *data, u32 count, u64 offset) struct p9_client *clnt; struct p9_req_t *req; char *dataptr; + struct kvec kv = {.iov_base = data, .iov_len = count}; + struct iov_iter to; + + iov_iter_kvec(&to, READ | ITER_KVEC, &kv, 1, count); p9_debug(P9_DEBUG_9P, ">>> TREADDIR fid %d offset %llu count %d\n", fid->fid, (unsigned long long) offset, count); @@ -2088,8 +2072,8 @@ int p9_client_readdir(struct p9_fid *fid, char *data, u32 count, u64 offset) * response header len is 11 * PDU Header(7) + IO Size (4) */ - req = p9_client_zc_rpc(clnt, P9_TREADDIR, data, NULL, rsize, 0, - 11, 1, "dqd", fid->fid, offset, rsize); + req = p9_client_zc_rpc(clnt, P9_TREADDIR, &to, NULL, rsize, 0, + 11, "dqd", fid->fid, offset, rsize); } else { non_zc = 1; req = p9_client_rpc(clnt, P9_TREADDIR, "dqd", fid->fid, diff --git a/net/9p/protocol.c b/net/9p/protocol.c index ab9127ec5b7a..16d287565987 100644 --- a/net/9p/protocol.c +++ b/net/9p/protocol.c @@ -33,6 +33,7 @@ #include <linux/sched.h> #include <linux/stddef.h> #include <linux/types.h> +#include <linux/uio.h> #include <net/9p/9p.h> #include <net/9p/client.h> #include "protocol.h" @@ -69,10 +70,11 @@ static size_t pdu_write(struct p9_fcall *pdu, const void *data, size_t size) } static size_t -pdu_write_u(struct p9_fcall *pdu, const char __user *udata, size_t size) +pdu_write_u(struct p9_fcall *pdu, struct iov_iter *from, size_t size) { size_t len = min(pdu->capacity - pdu->size, size); - if (copy_from_user(&pdu->sdata[pdu->size], udata, len)) + struct iov_iter i = *from; + if (copy_from_iter(&pdu->sdata[pdu->size], len, &i) != len) len = 0; pdu->size += len; @@ -273,7 +275,7 @@ p9pdu_vreadf(struct p9_fcall *pdu, int proto_version, const char *fmt, } break; case 'R':{ - int16_t *nwqid = va_arg(ap, int16_t *); + uint16_t *nwqid = va_arg(ap, uint16_t *); struct p9_qid **wqids = va_arg(ap, struct p9_qid **); @@ -437,23 +439,13 @@ p9pdu_vwritef(struct p9_fcall *pdu, int proto_version, const char *fmt, stbuf->extension, stbuf->n_uid, stbuf->n_gid, stbuf->n_muid); } break; - case 'D':{ + case 'V':{ uint32_t count = va_arg(ap, uint32_t); - const void *data = va_arg(ap, const void *); - - errcode = p9pdu_writef(pdu, proto_version, "d", - count); - if (!errcode && pdu_write(pdu, data, count)) - errcode = -EFAULT; - } - break; - case 'U':{ - int32_t count = va_arg(ap, int32_t); - const char __user *udata = - va_arg(ap, const void __user *); + struct iov_iter *from = + va_arg(ap, struct iov_iter *); errcode = p9pdu_writef(pdu, proto_version, "d", count); - if (!errcode && pdu_write_u(pdu, udata, count)) + if (!errcode && pdu_write_u(pdu, from, count)) errcode = -EFAULT; } break; @@ -479,7 +471,7 @@ p9pdu_vwritef(struct p9_fcall *pdu, int proto_version, const char *fmt, } break; case 'R':{ - int16_t nwqid = va_arg(ap, int); + uint16_t nwqid = va_arg(ap, int); struct p9_qid *wqids = va_arg(ap, struct p9_qid *); diff --git a/net/9p/trans_common.c b/net/9p/trans_common.c index 2ee3879161b1..38aa6345bdfa 100644 --- a/net/9p/trans_common.c +++ b/net/9p/trans_common.c @@ -12,12 +12,8 @@ * */ -#include <linux/slab.h> +#include <linux/mm.h> #include <linux/module.h> -#include <net/9p/9p.h> -#include <net/9p/client.h> -#include <linux/scatterlist.h> -#include "trans_common.h" /** * p9_release_req_pages - Release pages after the transaction. @@ -31,39 +27,3 @@ void p9_release_pages(struct page **pages, int nr_pages) put_page(pages[i]); } EXPORT_SYMBOL(p9_release_pages); - -/** - * p9_nr_pages - Return number of pages needed to accommodate the payload. - */ -int p9_nr_pages(char *data, int len) -{ - unsigned long start_page, end_page; - start_page = (unsigned long)data >> PAGE_SHIFT; - end_page = ((unsigned long)data + len + PAGE_SIZE - 1) >> PAGE_SHIFT; - return end_page - start_page; -} -EXPORT_SYMBOL(p9_nr_pages); - -/** - * payload_gup - Translates user buffer into kernel pages and - * pins them either for read/write through get_user_pages_fast(). - * @req: Request to be sent to server. - * @pdata_off: data offset into the first page after translation (gup). - * @pdata_len: Total length of the IO. gup may not return requested # of pages. - * @nr_pages: number of pages to accommodate the payload - * @rw: Indicates if the pages are for read or write. - */ - -int p9_payload_gup(char *data, int *nr_pages, struct page **pages, int write) -{ - int nr_mapped_pages; - - nr_mapped_pages = get_user_pages_fast((unsigned long)data, - *nr_pages, write, pages); - if (nr_mapped_pages <= 0) - return nr_mapped_pages; - - *nr_pages = nr_mapped_pages; - return 0; -} -EXPORT_SYMBOL(p9_payload_gup); diff --git a/net/9p/trans_common.h b/net/9p/trans_common.h index 173bb550a9eb..c43babb3f635 100644 --- a/net/9p/trans_common.h +++ b/net/9p/trans_common.h @@ -13,5 +13,3 @@ */ void p9_release_pages(struct page **, int); -int p9_payload_gup(char *, int *, struct page **, int); -int p9_nr_pages(char *, int); diff --git a/net/9p/trans_fd.c b/net/9p/trans_fd.c index 80d08f6664cb..bced8c074c12 100644 --- a/net/9p/trans_fd.c +++ b/net/9p/trans_fd.c @@ -734,6 +734,7 @@ static int parse_opts(char *params, struct p9_fd_opts *opts) opts->port = P9_PORT; opts->rfd = ~0; opts->wfd = ~0; + opts->privport = 0; if (!params) return 0; @@ -940,7 +941,7 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args) sin_server.sin_family = AF_INET; sin_server.sin_addr.s_addr = in_aton(addr); sin_server.sin_port = htons(opts.port); - err = __sock_create(read_pnet(¤t->nsproxy->net_ns), PF_INET, + err = __sock_create(current->nsproxy->net_ns, PF_INET, SOCK_STREAM, IPPROTO_TCP, &csocket, 1); if (err) { pr_err("%s (%d): problem creating socket\n", @@ -988,7 +989,7 @@ p9_fd_create_unix(struct p9_client *client, const char *addr, char *args) sun_server.sun_family = PF_UNIX; strcpy(sun_server.sun_path, addr); - err = __sock_create(read_pnet(¤t->nsproxy->net_ns), PF_UNIX, + err = __sock_create(current->nsproxy->net_ns, PF_UNIX, SOCK_STREAM, 0, &csocket, 1); if (err < 0) { pr_err("%s (%d): problem creating socket\n", @@ -1013,7 +1014,6 @@ p9_fd_create(struct p9_client *client, const char *addr, char *args) { int err; struct p9_fd_opts opts; - struct p9_trans_fd *p; parse_opts(args, &opts); @@ -1026,7 +1026,6 @@ p9_fd_create(struct p9_client *client, const char *addr, char *args) if (err < 0) return err; - p = (struct p9_trans_fd *) client->trans; p9_conn_create(client); return 0; diff --git a/net/9p/trans_rdma.c b/net/9p/trans_rdma.c index 14ad43b5cf89..3533d2a53ab6 100644 --- a/net/9p/trans_rdma.c +++ b/net/9p/trans_rdma.c @@ -139,6 +139,7 @@ struct p9_rdma_opts { int sq_depth; int rq_depth; long timeout; + int privport; }; /* @@ -146,7 +147,10 @@ struct p9_rdma_opts { */ enum { /* Options that take integer arguments */ - Opt_port, Opt_rq_depth, Opt_sq_depth, Opt_timeout, Opt_err, + Opt_port, Opt_rq_depth, Opt_sq_depth, Opt_timeout, + /* Options that take no argument */ + Opt_privport, + Opt_err, }; static match_table_t tokens = { @@ -154,6 +158,7 @@ static match_table_t tokens = { {Opt_sq_depth, "sq=%u"}, {Opt_rq_depth, "rq=%u"}, {Opt_timeout, "timeout=%u"}, + {Opt_privport, "privport"}, {Opt_err, NULL}, }; @@ -175,6 +180,7 @@ static int parse_opts(char *params, struct p9_rdma_opts *opts) opts->sq_depth = P9_RDMA_SQ_DEPTH; opts->rq_depth = P9_RDMA_RQ_DEPTH; opts->timeout = P9_RDMA_TIMEOUT; + opts->privport = 0; if (!params) return 0; @@ -193,13 +199,13 @@ static int parse_opts(char *params, struct p9_rdma_opts *opts) if (!*p) continue; token = match_token(p, tokens, args); - if (token == Opt_err) - continue; - r = match_int(&args[0], &option); - if (r < 0) { - p9_debug(P9_DEBUG_ERROR, - "integer field, but no integer?\n"); - continue; + if ((token != Opt_err) && (token != Opt_privport)) { + r = match_int(&args[0], &option); + if (r < 0) { + p9_debug(P9_DEBUG_ERROR, + "integer field, but no integer?\n"); + continue; + } } switch (token) { case Opt_port: @@ -214,6 +220,9 @@ static int parse_opts(char *params, struct p9_rdma_opts *opts) case Opt_timeout: opts->timeout = option; break; + case Opt_privport: + opts->privport = 1; + break; default: continue; } @@ -607,6 +616,23 @@ static int rdma_cancelled(struct p9_client *client, struct p9_req_t *req) return 0; } +static int p9_rdma_bind_privport(struct p9_trans_rdma *rdma) +{ + struct sockaddr_in cl = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + }; + int port, err = -EINVAL; + + for (port = P9_DEF_MAX_RESVPORT; port >= P9_DEF_MIN_RESVPORT; port--) { + cl.sin_port = htons((ushort)port); + err = rdma_bind_addr(rdma->cm_id, (struct sockaddr *)&cl); + if (err != -EADDRINUSE) + break; + } + return err; +} + /** * trans_create_rdma - Transport method for creating atransport instance * @client: client instance @@ -642,6 +668,16 @@ rdma_create_trans(struct p9_client *client, const char *addr, char *args) /* Associate the client with the transport */ client->trans = rdma; + /* Bind to a privileged port if we need to */ + if (opts.privport) { + err = p9_rdma_bind_privport(rdma); + if (err < 0) { + pr_err("%s (%d): problem binding to privport: %d\n", + __func__, task_pid_nr(current), -err); + goto error; + } + } + /* Resolve the server's address */ rdma->addr.sin_family = AF_INET; rdma->addr.sin_addr.s_addr = in_aton(addr); diff --git a/net/9p/trans_virtio.c b/net/9p/trans_virtio.c index 36a1a739ad68..9dd49ca67dbc 100644 --- a/net/9p/trans_virtio.c +++ b/net/9p/trans_virtio.c @@ -217,15 +217,15 @@ static int p9_virtio_cancel(struct p9_client *client, struct p9_req_t *req) * @start: which segment of the sg_list to start at * @pdata: a list of pages to add into sg. * @nr_pages: number of pages to pack into the scatter/gather list - * @data: data to pack into scatter/gather list + * @offs: amount of data in the beginning of first page _not_ to pack * @count: amount of data to pack into the scatter/gather list */ static int pack_sg_list_p(struct scatterlist *sg, int start, int limit, - struct page **pdata, int nr_pages, char *data, int count) + struct page **pdata, int nr_pages, size_t offs, int count) { int i = 0, s; - int data_off; + int data_off = offs; int index = start; BUG_ON(nr_pages > (limit - start)); @@ -233,16 +233,14 @@ pack_sg_list_p(struct scatterlist *sg, int start, int limit, * if the first page doesn't start at * page boundary find the offset */ - data_off = offset_in_page(data); while (nr_pages) { - s = rest_of_page(data); + s = PAGE_SIZE - data_off; if (s > count) s = count; /* Make sure we don't terminate early. */ sg_unmark_end(&sg[index]); sg_set_page(&sg[index++], pdata[i++], s, data_off); data_off = 0; - data += s; count -= s; nr_pages--; } @@ -314,11 +312,20 @@ req_retry: } static int p9_get_mapped_pages(struct virtio_chan *chan, - struct page **pages, char *data, - int nr_pages, int write, int kern_buf) + struct page ***pages, + struct iov_iter *data, + int count, + size_t *offs, + int *need_drop) { + int nr_pages; int err; - if (!kern_buf) { + + if (!iov_iter_count(data)) + return 0; + + if (!(data->type & ITER_KVEC)) { + int n; /* * We allow only p9_max_pages pinned. We wait for the * Other zc request to finish here @@ -329,26 +336,49 @@ static int p9_get_mapped_pages(struct virtio_chan *chan, if (err == -ERESTARTSYS) return err; } - err = p9_payload_gup(data, &nr_pages, pages, write); - if (err < 0) - return err; + n = iov_iter_get_pages_alloc(data, pages, count, offs); + if (n < 0) + return n; + *need_drop = 1; + nr_pages = DIV_ROUND_UP(n + *offs, PAGE_SIZE); atomic_add(nr_pages, &vp_pinned); + return n; } else { /* kernel buffer, no need to pin pages */ - int s, index = 0; - int count = nr_pages; - while (nr_pages) { - s = rest_of_page(data); - if (is_vmalloc_addr(data)) - pages[index++] = vmalloc_to_page(data); + int index; + size_t len; + void *p; + + /* we'd already checked that it's non-empty */ + while (1) { + len = iov_iter_single_seg_count(data); + if (likely(len)) { + p = data->kvec->iov_base + data->iov_offset; + break; + } + iov_iter_advance(data, 0); + } + if (len > count) + len = count; + + nr_pages = DIV_ROUND_UP((unsigned long)p + len, PAGE_SIZE) - + (unsigned long)p / PAGE_SIZE; + + *pages = kmalloc(sizeof(struct page *) * nr_pages, GFP_NOFS); + if (!*pages) + return -ENOMEM; + + *need_drop = 0; + p -= (*offs = (unsigned long)p % PAGE_SIZE); + for (index = 0; index < nr_pages; index++) { + if (is_vmalloc_addr(p)) + (*pages)[index] = vmalloc_to_page(p); else - pages[index++] = kmap_to_page(data); - data += s; - nr_pages--; + (*pages)[index] = kmap_to_page(p); + p += PAGE_SIZE; } - nr_pages = count; + return len; } - return nr_pages; } /** @@ -364,8 +394,8 @@ static int p9_get_mapped_pages(struct virtio_chan *chan, */ static int p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req, - char *uidata, char *uodata, int inlen, - int outlen, int in_hdr_len, int kern_buf) + struct iov_iter *uidata, struct iov_iter *uodata, + int inlen, int outlen, int in_hdr_len) { int in, out, err, out_sgs, in_sgs; unsigned long flags; @@ -373,41 +403,32 @@ p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req, struct page **in_pages = NULL, **out_pages = NULL; struct virtio_chan *chan = client->trans; struct scatterlist *sgs[4]; + size_t offs; + int need_drop = 0; p9_debug(P9_DEBUG_TRANS, "virtio request\n"); if (uodata) { - out_nr_pages = p9_nr_pages(uodata, outlen); - out_pages = kmalloc(sizeof(struct page *) * out_nr_pages, - GFP_NOFS); - if (!out_pages) { - err = -ENOMEM; - goto err_out; + int n = p9_get_mapped_pages(chan, &out_pages, uodata, + outlen, &offs, &need_drop); + if (n < 0) + return n; + out_nr_pages = DIV_ROUND_UP(n + offs, PAGE_SIZE); + if (n != outlen) { + __le32 v = cpu_to_le32(n); + memcpy(&req->tc->sdata[req->tc->size - 4], &v, 4); + outlen = n; } - out_nr_pages = p9_get_mapped_pages(chan, out_pages, uodata, - out_nr_pages, 0, kern_buf); - if (out_nr_pages < 0) { - err = out_nr_pages; - kfree(out_pages); - out_pages = NULL; - goto err_out; - } - } - if (uidata) { - in_nr_pages = p9_nr_pages(uidata, inlen); - in_pages = kmalloc(sizeof(struct page *) * in_nr_pages, - GFP_NOFS); - if (!in_pages) { - err = -ENOMEM; - goto err_out; - } - in_nr_pages = p9_get_mapped_pages(chan, in_pages, uidata, - in_nr_pages, 1, kern_buf); - if (in_nr_pages < 0) { - err = in_nr_pages; - kfree(in_pages); - in_pages = NULL; - goto err_out; + } else if (uidata) { + int n = p9_get_mapped_pages(chan, &in_pages, uidata, + inlen, &offs, &need_drop); + if (n < 0) + return n; + in_nr_pages = DIV_ROUND_UP(n + offs, PAGE_SIZE); + if (n != inlen) { + __le32 v = cpu_to_le32(n); + memcpy(&req->tc->sdata[req->tc->size - 4], &v, 4); + inlen = n; } } req->status = REQ_STATUS_SENT; @@ -426,7 +447,7 @@ req_retry_pinned: if (out_pages) { sgs[out_sgs++] = chan->sg + out; out += pack_sg_list_p(chan->sg, out, VIRTQUEUE_NUM, - out_pages, out_nr_pages, uodata, outlen); + out_pages, out_nr_pages, offs, outlen); } /* @@ -444,7 +465,7 @@ req_retry_pinned: if (in_pages) { sgs[out_sgs + in_sgs++] = chan->sg + out + in; in += pack_sg_list_p(chan->sg, out + in, VIRTQUEUE_NUM, - in_pages, in_nr_pages, uidata, inlen); + in_pages, in_nr_pages, offs, inlen); } BUG_ON(out_sgs + in_sgs > ARRAY_SIZE(sgs)); @@ -478,7 +499,7 @@ req_retry_pinned: * Non kernel buffers are pinned, unpin them */ err_out: - if (!kern_buf) { + if (need_drop) { if (in_pages) { p9_release_pages(in_pages, in_nr_pages); atomic_sub(in_nr_pages, &vp_pinned); @@ -504,7 +525,10 @@ static ssize_t p9_mount_tag_show(struct device *dev, vdev = dev_to_virtio(dev); chan = vdev->priv; - return snprintf(buf, chan->tag_len + 1, "%s", chan->tag); + memcpy(buf, chan->tag, chan->tag_len); + buf[chan->tag_len] = 0; + + return chan->tag_len + 1; } static DEVICE_ATTR(mount_tag, 0444, p9_mount_tag_show, NULL); |